OpenDAS / Megatron-LM · Commits · 2c3cb9fc

Commit 2c3cb9fc authored Mar 24, 2022 by Lawrence McAfee

    many edits; working towards first draft.

Parent: 867105c2

Showing 2 changed files with 456 additions and 271 deletions:

    megatron/optimizer/distrib_optimizer.py    +426  -259
    megatron/optimizer/optimizer.py             +30   -12
megatron/optimizer/distrib_optimizer.py   (view file @ 2c3cb9fc)
@@ -26,116 +26,122 @@ from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
 # from .optimizer import Float16OptimizerWithFloat16Params
 # >>>
 from lutil import pax, tp, print_seq
 # <<<

-class Shard:
+# >>>
+# class Shard:
+class Range:
+# <<<

     def __init__(self, start, end):
         self.start = start
         self.end = end
         self.size = end - start

     def normalize(self, start = 0):
-        return Shard(start, start + self.size)
+        return Range(start, start + self.size)

     def __str__(self):
         return "%d,%d [%d]" % (self.start, self.end, self.size)

 # class DistributedOptimizer(Float16OptimizerWithFloat16Params):
 class DistributedOptimizer(MixedPrecisionOptimizer):

     @classmethod
-    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
+    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):

-        # Param shard map.
+        # Param range map.
         param_world_index_map = model._grad_buffer_param_index_map[dtype]
-        param_shard_map = {}
+        param_range_map = {}
         for param, param_world_indexes in param_world_index_map.items():

-            # Shard range.
+            # Param range.
             param_world_start, param_world_end = param_world_indexes
             param_local_start = max(
                 0,
-                param_world_start - gbuf_world_shard.start)
+                param_world_start - gbuf_world_range.start)
             param_local_end = min(
-                gbuf_world_shard.size,
-                param_world_end - gbuf_world_shard.start)
+                gbuf_world_range.size,
+                param_world_end - gbuf_world_range.start)

-            # Add shard, if within range.
+            # Add param, if within local gbuf range.
             if param_local_end > param_local_start:
-                param_local_shard = Shard(param_local_start, param_local_end)
-                param_world_shard = param_local_shard.normalize(
-                    param_local_start + gbuf_world_shard.start)
-                sub_param_start = max(0, gbuf_world_shard.start - param_world_start)
-                sub_param_shard = param_local_shard.normalize(sub_param_start)
-                param_shard_map[param] = {
-                    "gbuf_world" : param_world_shard,
-                    "gbuf_local" : param_local_shard,
-                    "param" : sub_param_shard,
+                param_local_range = Range(param_local_start, param_local_end)
+                param_world_range = param_local_range.normalize(
+                    param_local_start + gbuf_world_range.start)
+                sub_param_start = max(0, gbuf_world_range.start - param_world_start)
+                sub_param_range = param_local_range.normalize(sub_param_start)
+                param_range_map[param] = {
+                    "gbuf_world" : param_world_range,
+                    "gbuf_local" : param_local_range,
+                    "param" : sub_param_range,
                 }

-        return param_shard_map
+        return param_range_map

     @classmethod
-    def get_model_gbuf_shard(cls, model, dtype):
+    def build_model_gbuf_range(cls, model, dtype):

         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()

-        # Grad buffer shard.
+        # Grad buffer range.
         grad_buffer = model._grad_buffers[dtype]
         gbuf_size = grad_buffer.numel
-        max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
+        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

-        # All world shards. (i.e., across all data parallel ranks)
-        gbuf_world_all_shards = []
+        # All world ranges. (i.e., across all data parallel ranks)
+        gbuf_world_all_ranges = []
         for r in range(data_parallel_world_size):
-            gbuf_world_start = r * max_gbuf_shard_size
-            gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
-            gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
-            gbuf_world_all_shards.append(gbuf_world_shard)
+            gbuf_world_start = r * max_gbuf_range_size
+            gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_range_size)
+            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
+            gbuf_world_all_ranges.append(gbuf_world_range)

-        # Local DP's shards.
-        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
-        gbuf_local_shard = gbuf_world_shard.normalize()
+        # Local DP's ranges.
+        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
+        gbuf_local_range = gbuf_world_range.normalize()

-        # Get each param's shards.
-        param_shard_map = cls.get_model_gbuf_param_shard_map(model, dtype,
-                                                             gbuf_world_shard)
+        # Get each param's ranges.
+        param_range_map = cls.build_model_gbuf_param_range_map(model, dtype,
                                                               gbuf_world_range)

         # Altogether.
         data = {
-            "local" : gbuf_local_shard,
-            "world" : gbuf_world_shard,
-            "world_all" : gbuf_world_all_shards,
-            "param_map" : param_shard_map,
-            "max_shard_size" : max_gbuf_shard_size,
+            "local" : gbuf_local_range,
+            "world" : gbuf_world_range,
+            "world_all" : gbuf_world_all_ranges,
+            "param_map" : param_range_map,
+            "max_range_size" : max_gbuf_range_size,
         }

         return data

     @classmethod
-    def get_model_gbuf_shard_map(cls, model):
+    def build_model_gbuf_range_map(cls, model):
         return {
-            dtype : cls.get_model_gbuf_shard(model, dtype)
+            dtype : cls.build_model_gbuf_range(model, dtype)
             for dtype in model._grad_buffers
         }

     @classmethod
-    def get_param_gbuf_map(cls, model_gbuf_shards):
-        '''Create a reverse of the model_gbuf_shards, for referencing in
+    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
+        '''Create a reverse of the model_gbuf_ranges, for referencing in
         opposite direction.'''
         param_gbuf_map = {}
-        for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
-            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-                for param, param_shard_map in gbuf_shard_map["param_map"].items():
+        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
+            for dtype, gbuf_range_map in model_gbuf_range_map.items():
+                for param, param_range_map in gbuf_range_map["param_map"].items():
                     param_gbuf_map[param] = (model_index, dtype)
         return param_gbuf_map

     # >>>
     # @classmethod
-    # def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
+    # def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
     #     num_groups = len(param_groups)
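For reference, the arithmetic in build_model_gbuf_range / build_model_gbuf_param_range_map above can be exercised standalone. The sketch below is illustrative only and is not part of the commit; the buffer size, data-parallel world size, and parameter indexes are made up, and the Range class is copied from the hunk above.

    import math

    class Range:
        def __init__(self, start, end):
            self.start, self.end, self.size = start, end, end - start
        def normalize(self, start=0):
            return Range(start, start + self.size)
        def __str__(self):
            return "%d,%d [%d]" % (self.start, self.end, self.size)

    gbuf_size = 100       # made-up grad buffer numel
    dp_world_size = 4     # made-up data-parallel world size
    max_range_size = int(math.ceil(gbuf_size / dp_world_size))   # 25

    # Each data-parallel rank owns one contiguous slice of the grad buffer.
    gbuf_world_all_ranges = [Range(r * max_range_size,
                                   min(gbuf_size, (r + 1) * max_range_size))
                             for r in range(dp_world_size)]

    # A parameter occupying world indexes [40, 70) overlaps ranks 1 and 2.
    param_world_start, param_world_end = 40, 70
    for rank, gbuf_world_range in enumerate(gbuf_world_all_ranges):
        local_start = max(0, param_world_start - gbuf_world_range.start)
        local_end = min(gbuf_world_range.size,
                        param_world_end - gbuf_world_range.start)
        if local_end > local_start:
            param_local_range = Range(local_start, local_end)
            sub_param_range = param_local_range.normalize(
                max(0, gbuf_world_range.start - param_world_start))
            print(rank, param_local_range, sub_param_range)
    # -> rank 1 owns buffer-local [15,25), i.e. param elements [0,10);
    #    rank 2 owns buffer-local [0,20),  i.e. param elements [10,30).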
@@ -146,31 +152,31 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #         assert param.requires_grad
     #         param_group_map[param] = group_index

-    #     # Optimizer group shards.
-    #     group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
-    #     for model_gbuf_shard_map in model_gbuf_shards:
-    #         for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-    #             for param in gbuf_shard_map["param_map"]:
+    #     # Optimizer group ranges.
+    #     group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+    #     for model_gbuf_range_map in model_gbuf_ranges:
+    #         for dtype, gbuf_range_map in model_gbuf_range_map.items():
+    #             for param in gbuf_range_map["param_map"]:
     #                 group_index = param_group_map[param]
-    #                 group_shard = group_shards[group_index]
-    #                 param_size = gbuf_shard_map["param_map"][param]["param"].size
-    #                 param_group_start = group_shard["size"]
+    #                 group_range = group_ranges[group_index]
+    #                 param_size = gbuf_range_map["param_map"][param]["param"].size
+    #                 param_group_start = group_range["size"]
     #                 param_group_end = param_group_start + param_size
-    #                 param_group_shard = Shard(param_group_start, param_group_end)
-    #                 group_shard["size"] += param_size
-    #                 group_shard["param_map"][param] = param_group_shard
+    #                 param_group_range = Range(param_group_start, param_group_end)
+    #                 group_range["size"] += param_size
+    #                 group_range["param_map"][param] = param_group_range

-    #     # Squeeze zero-size group shards.
-    #     for group_index, group_shard in enumerate(group_shards):
-    #         group_shard["orig_group"] = param_groups[group_index]
-    #     group_shards = [ g for g in group_shards if g["size"] > 0 ]
+    #     # Squeeze zero-size group ranges.
+    #     for group_index, group_range in enumerate(group_ranges):
+    #         group_range["orig_group"] = param_groups[group_index]
+    #     group_ranges = [ g for g in group_ranges if g["size"] > 0 ]

-    #     return group_shards
+    #     return group_ranges

     @classmethod
-    def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
+    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):

         num_groups = len(param_groups)
@@ -181,35 +187,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             assert param.requires_grad
             param_group_map[param] = group_index

-        # Optimizer group shards.
+        # Optimizer group ranges.
         # >>>
-        # group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
-        group_shards = [ {"params": []} for _ in param_groups ]
-        # group_shards = [ [] for _ in param_groups ]
+        # group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+        group_ranges = [ {"params": []} for _ in param_groups ]
+        # group_ranges = [ [] for _ in param_groups ]
         # <<<
-        for model_gbuf_shard_map in model_gbuf_shards:
-            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-                for param in gbuf_shard_map["param_map"]:
+        for model_gbuf_range_map in model_gbuf_ranges:
+            for dtype, gbuf_range_map in model_gbuf_range_map.items():
+                for param in gbuf_range_map["param_map"]:
                     group_index = param_group_map[param]
-                    group_shard = group_shards[group_index]
-                    group_shard["params"].append(param)
+                    group_range = group_ranges[group_index]
+                    group_range["params"].append(param)

-        # Squeeze zero-size group shards.
-        for group_index, group_shard in enumerate(group_shards):
-            group_shard["orig_group"] = param_groups[group_index]
-        group_shards = [ g for g in group_shards if len(g["params"]) > 0 ]
+        # Squeeze zero-size group ranges.
+        for group_index, group_range in enumerate(group_ranges):
+            group_range["orig_group"] = param_groups[group_index]
+        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

         # >>>
-        # print_seq("group shards / len = %s." %
-        #           ", ".join(str(len(s["params"])) for s in group_shards))
+        # print_seq("group ranges / len = %s." %
+        #           ", ".join(str(len(s["params"])) for s in group_ranges))
         # <<<

-        return group_shards
+        return group_ranges
     # <<<

     # >>>
     # @classmethod
-    # def allocate_main_param_shards(cls, opt_group_shards):
+    # def allocate_main_param_shards(cls, opt_group_ranges):

     #     # Allocator method.
     #     allocate_shard = lambda shard_size, dtype : torch.empty(
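The build_optimizer_group_ranges logic above reduces to bucketing each model parameter into the optimizer param group it came from and dropping empty buckets. A minimal illustration, not part of the commit; the group dicts and parameter names are hypothetical stand-ins.

    param_groups = [{"wd_mult": 1.0}, {"wd_mult": 0.0}]   # stand-in optimizer groups
    param_group_map = {"w1": 0, "w2": 0, "b1": 1}          # param -> group index
    params_in_gbufs = ["w1", "w2"]                          # params seen in this rank's grad buffers

    group_ranges = [{"params": []} for _ in param_groups]
    for param in params_in_gbufs:
        group_ranges[param_group_map[param]]["params"].append(param)

    # Squeeze zero-size groups, keeping a handle to the original group dict.
    for group_index, group_range in enumerate(group_ranges):
        group_range["orig_group"] = param_groups[group_index]
    group_ranges = [g for g in group_ranges if len(g["params"]) > 0]
    print(group_ranges)   # only group 0 survives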
@@ -219,9 +225,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #         requires_grad = True)

     #     # Allocate each group's param/grad shard.
-    #     for group_index, group_shard in enumerate(opt_group_shards):
-    #         group_size = group_shard["size"]
+    #     for group_index, group_range in enumerate(opt_group_ranges):
+    #         group_size = group_range["size"]
     #         assert group_size != 0, "temporary check ... remove me."

     #         # Allocate shard.
@@ -230,71 +236,74 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #         mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)

     #         # Update group's param.
-    #         group_shard["orig_group"]["params"] = [ main_param ]
+    #         group_range["orig_group"]["params"] = [ main_param ]

     @classmethod
-    # def allocate_main_params(cls, opt_group_shards):
-    def allocate_or_view_main_param_shards(cls,
-                                           model_gbuf_shards,
-                                           param_gbuf_map,
-                                           opt_group_shards):
-
-        # # Allocator method.
-        # allocate_shard = lambda shard_size, dtype : torch.empty(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
-
-        # Allocate each group's param/grad shard.
-        for group_index, group_shard in enumerate(opt_group_shards):
-
-            # group_size = group_shard["size"]
-            # assert group_size != 0, "temporary check ... remove me."
-
-            # # Allocate shard.
-            # main_param = allocate_shard(group_size, torch.float)
-            # main_param.grad = allocate_shard(group_size, torch.float)
-            # mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
-
-            # # Update group's param.
-            # group_shard["orig_group"]["params"] = [ main_param ]
-
-            group_main_params = []
-            group_shard["orig_group"]["params"] = group_main_params
-
-            for param in group_shard["params"]:
-
-                model_index, dtype = param_gbuf_map[param]
-                gbuf_shard = model_gbuf_shards[model_index][dtype]
-                param_shard = gbuf_shard["param_map"][param]["param"]
-
-                pax(0, {
-                    "model_index" : model_index,
-                    "dtype" : dtype,
-                    "gbuf_shard" : gbuf_shard,
-                    "param_shard" : param_shard,
-                })
+    # def allocate_main_params(cls, opt_group_ranges):
+    # def allocate_or_view_main_param_shards(cls,
+    def build_model_and_main_param_groups(cls,
+                                          model_gbuf_ranges,
+                                          param_gbuf_map,
+                                          opt_group_ranges):
+
+        # Three groups of parameters:
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
+        #   fp32_groups: original fp32 parameters
+        full_float16_groups = []
+        full_fp32_groups = []
+        shard_float16_groups = []
+        shard_fp32_groups = []
+        shard_fp32_from_float16_groups = []
+
+        # Allocate each group's param shard.
+        for group_index, group_range in enumerate(opt_group_ranges):
+
+            # Params of this group.
+            full_float16_params_this_group = []
+            full_fp32_params_this_group = []
+            shard_float16_params_this_group = []
+            shard_fp32_params_this_group = []
+            shard_fp32_from_float16_params_this_group = []
+            full_float16_groups.append(full_float16_params_this_group)
+            full_fp32_groups.append(full_fp32_params_this_group)
+            shard_float16_groups.append(shard_float16_params_this_group)
+            shard_fp32_groups.append(shard_fp32_params_this_group)
+            shard_fp32_from_float16_groups.append(
+                shard_fp32_from_float16_params_this_group)
+
+            for model_param in group_range["params"]:
+
+                model_index, dtype = param_gbuf_map[model_param]
+                gbuf_range = model_gbuf_ranges[model_index][dtype]
+                param_range = gbuf_range["param_map"][model_param]["param"]

                 # fp16, bf16 params.
-                if param.type() in ['torch.cuda.HalfTensor',
-                                    'torch.cuda.BFloat16Tensor']:
-
-                    # Allocate/copy main param/grad.
-                    main_param = param.detach()[param_shard.start:param_shard.end].clone().float()
-                    if accumulate_allreduce_grads_in_fp32:
-                        main_param.grad = param.main_grad[param_shard.start:param_shard.end]
-                    else:
-                        main_param.grad = param.main_grad.detach()[param_shard.start:param_shard.end].clone().float()
-
-                    # Copy tensor model parallel attributes.
-                    mpu.copy_tensor_model_parallel_attributes(main_param, param)
-                    if hasattr(param, 'shared'):
-                        main_param.shared = param.shared
+                if model_param.type() in ['torch.cuda.HalfTensor',
+                                          'torch.cuda.BFloat16Tensor']:
+
+                    # Clone model -> main.
+                    shard_model_param = \
+                        model_param.detach()[param_range.start:param_range.end]
+                    shard_main_param = shard_model_param.clone().float()
+                    mpu.copy_tensor_model_parallel_attributes(
+                        shard_model_param, model_param)
+                    mpu.copy_tensor_model_parallel_attributes(
+                        shard_main_param, model_param)
+                    if hasattr(model_param, 'shared'):
+                        shard_model_param.shared = model_param.shared
+                        shard_main_param.shared = model_param.shared
+
+                    # Add to group.
+                    full_float16_params_this_group.append(model_param)
+                    shard_float16_params_this_group.append(shard_model_param)
+                    shard_fp32_from_float16_params_this_group.append(
+                        shard_main_param)

                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
-                    main_param = param
-                    main_param.grad = param.main_grad
+                    shard_model_param = \
+                        model_param[param_range.start:param_range.end]
+                    full_fp32_params_this_group.append(model_param)
+                    shard_fp32_params_this_group.append(shard_model_param)

                 else:
                     raise TypeError('Wrapped parameters must be one of '
@@ -303,23 +312,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                     'torch.cuda.BFloat16Tensor. '
                                     'Received {}'.format(param.type()))

-                # Add to group.
-                group_main_params.append(main_param)
+                # # Add to group.
+                # group_main_params.append(main_param)
+
+            group_range["orig_group"]["params"] = [
+                *shard_fp32_params_this_group,
+                *shard_fp32_from_float16_params_this_group,
+            ]
+
+        return (
+            full_float16_groups,
+            full_fp32_groups,
+            shard_float16_groups,
+            shard_fp32_groups,
+            shard_fp32_from_float16_groups,
+        )
     # <<<

     # >>>
     # @classmethod
-    # def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
+    # def build_main_grad_views_for_grad_norm(cls, opt_group_ranges, optimizer):
     #     grad_views = []
-    #     for group_index, opt_group_shard in enumerate(opt_group_shards):
+    #     for group_index, opt_group_range in enumerate(opt_group_ranges):
     #         opt_grad = optimizer.param_groups[group_index]["params"][0].grad
-    #         for param, shard in opt_group_shard["param_map"].items():
+    #         for param, range in opt_group_range["param_map"].items():
     #             if param_is_not_shared(param) and \
     #                param_is_not_tensor_parallel_duplicate(param):
-    #                 grad_view = opt_grad[shard.start:shard.end]
+    #                 grad_view = opt_grad[range.start:range.end]
     #                 grad_views.append(grad_view)
     #     return grad_views
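In plain PyTorch terms, the per-parameter work that build_model_and_main_param_groups does for a float16 parameter looks roughly like the sketch below. Illustrative only, not part of the commit; the tensor size and the [start, end) range are made up.

    import torch

    model_param = torch.randn(30).half()   # stand-in for an fp16 model parameter
    start, end = 10, 30                     # this rank's sub-range of the parameter

    # fp16 view of the owned slice (shares storage with the model param) ...
    shard_model_param = model_param.detach()[start:end]
    # ... and a detached fp32 working copy that the optimizer will update.
    shard_main_param = shard_model_param.clone().float()

    assert shard_main_param.dtype == torch.float32
    assert shard_main_param.nelement() == end - start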
@@ -342,108 +363,162 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         assert use_contiguous_buffers_in_local_ddp
         # <<<

-        # Model grad buffer shards.
-        self.model_gbuf_shards = []
+        # Model grad buffer ranges.
+        self.model_gbuf_ranges = []
         for model_index, model in enumerate(self.models):
-            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
-        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
+            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
+        self.model_param_gbuf_map = \
+            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

-        # Optimizer shards.
-        self.opt_group_shards = self.get_optimizer_group_shards(
+        # Optimizer ranges.
+        self.opt_group_ranges = self.build_optimizer_group_ranges(
             self.optimizer.param_groups,
-            self.model_gbuf_shards)
+            self.model_gbuf_ranges)

-        # Allocate main param shards.
-        self.allocate_or_view_main_param_shards(
-            self.model_gbuf_shards,
-            self.param_gbuf_map,
-            self.opt_group_shards)
+        (
+            self.full_float16_groups,
+            self.full_fp32_groups,
+            self.shard_float16_groups,
+            self.shard_fp32_groups,
+            self.shard_fp32_from_float16_groups,
+        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
+                                                   self.model_param_gbuf_map,
+                                                   self.opt_group_ranges)

         # print_seq("16 [%d], 16x32 [%d], 32 [%d]." % (
         #     sum(len(g) for g in self.float16_groups),
         #     sum(len(g) for g in self.fp32_from_float16_groups),
         #     sum(len(g) for g in self.fp32_groups),
         # ))

         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
         #   recast preexisting per-param state tensors.
         self.optimizer.param_groups = \
-            [ g["orig_group"] for g in self.opt_group_shards ]
+            [ g["orig_group"] for g in self.opt_group_ranges ]
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # Initialize main params.
-        self._copy_model_params_to_main_params()
+        # >>>
+        # # Initialize main params.
+        # self._copy_model_params_to_main_params()
+        # <<<

         # >>>
         # # Params for grad norm.
-        # self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
-        #     self.opt_group_shards,
+        # self.main_grad_views_for_grad_norm = self.build_main_grad_views_for_grad_norm(
+        #     self.opt_group_ranges,
         #     self.optimizer)
         # <<<

+    def get_model_param_range_map(self, param):
+        model_index, dtype = self.model_param_gbuf_map[param]
+        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
+        param_range_map = gbuf_range_map["param_map"][param]
+        # >>>
+        # pax(0, {
+        #     "param" : param,
+        #     "model_index" : model_index,
+        #     "dtype" : str(dtype),
+        #     "gbuf_range_map" : gbuf_range_map,
+        #     "param_range_map" : param_range_map,
+        # })
+        # <<<
+        return param_range_map
+
     def get_model_parallel_group(self):
         return None

-    def get_main_params(self):
-        return [ g["params"][0] for g in self.optimizer.param_groups ]
-    def get_main_grads(self):
-        return [ p.grad for p in self.get_main_params() ]
-    def get_main_param(self, group_index):
-        return self.get_main_params()[group_index]
-    def get_main_grad(self, group_index):
-        return self.get_main_param(group_index).grad
+    # def get_main_params(self):
+    #     return [ g["params"][0] for g in self.optimizer.param_groups ]
+    # def get_main_grads(self):
+    #     return [ p.grad for p in self.get_main_params() ]
+    # def get_main_param(self, group_index):
+    #     return self.get_main_params()[group_index]
+    # def get_main_grad(self, group_index):
+    #     return self.get_main_param(group_index).grad

     # >>>
     # def get_main_grads_for_grad_norm(self):
     #     return self.main_grad_views_for_grad_norm
+    def get_main_grads_for_grad_norm(self):
+        raise Exception("does 'super' work?")
     # <<<

-    # def state_dict(self):
-    #     state_dict = {}
-    #     state_dict['optimizer'] = self.optimizer.state_dict()
-    #     if self.grad_scaler:
-    #         state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-    #     state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
-    #     return state_dict
+    def state_dict(self):
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
+        return state_dict
+        raise Exception("fix me.")

-    # def load_state_dict(self, state_dict):
-    #     # Optimizer.
-    #     optimizer_key = 'optimizer'
-    #     if optimizer_key not in state_dict:
-    #         optimizer_key = 'optimizer_state_dict'
-    #         print_rank_0('***WARNING*** loading optimizer from '
-    #                      'an old checkpoint ...')
-    #     self.optimizer.load_state_dict(state_dict[optimizer_key])
-    #     # Grad scaler.
-    #     if 'grad_scaler' not in state_dict:
-    #         print_rank_0('***WARNING*** found an old checkpoint, will not '
-    #                      'load grad scaler ...')
-    #     else:
-    #         if self.grad_scaler:
-    #             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-    #         else:
-    #             print_rank_0('***WARNING*** fould the grad scaler in the '
-    #                          'checkpoint but it is None in the class. '
-    #                          'Skipping loading grad scaler ...')
-    #     # Copy data for the main params.
-    #     current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-    #     assert "groups" in state_dict, "key 'groups' not in state_dict."
-    #     for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-    #         for current_param, saved_param in zip(current_group, saved_group):
-    #             current_param.data.copy_(saved_param.data)
+    def load_state_dict(self, state_dict):
+
+        # Optimizer.
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        self.optimizer.load_state_dict(state_dict[optimizer_key])
+
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            print_rank_0('***WARNING*** found an old checkpoint, will not '
+                         'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** fould the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')
+
+        # Copy data for the main params.
+        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
+        assert "groups" in state_dict, "key 'groups' not in state_dict."
+        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
+
+        raise Exception("hi.")

     # def zero_grad(self, set_to_none=True):
-    def zero_grad(self, set_to_none=True):
-
-        # Collect model params.
-        model_params = []
-        for model in self.models:
-            for dtype, param_map in model._grad_buffer_param_index_map.items():
-                model_params.extend(param_map.keys())
-
-        # Distributed optimizer requires contiguous buffer; don't set to None.
-        _zero_grad_group_helper(model_params, set_to_none = False)
+    #     # Collect model params.
+    #     model_params = []
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             model_params.extend(param_map.keys())
+    #     # Distributed optimizer requires contiguous buffer; don't set to None.
+    #     _zero_grad_group_helper(model_params, set_to_none = False)

     # def zero_grad(self, set_to_none=True):
     #     raise Exception("does 'super' work?")
     # >>>
+    def zero_grad(self, set_to_none=True):
+        """We only need to zero the model related parameters, i.e.,
+        float16_groups & fp32_groups. We additionally zero
+        fp32_from_float16_groups as a memory optimization to reduce
+        fragmentation; in the case of set_to_none==True, the space
+        used by this field can be safely deallocated at this point."""
+        for groups in (
+                self.full_float16_groups,
+                self.full_fp32_groups,
+                self.shard_fp32_from_float16_groups):
+            for group in groups:
+                _zero_grad_group_helper(group, set_to_none)
     # <<<

     def get_model_grad_buffer_dp_views(self):
@@ -469,6 +544,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         grads.
         '''

+        # >>>
+        # print_seq([
+        #     tp(b.data)
+        #     for m in self.models
+        #     for b in m._grad_buffers.values()
+        # ])
+        # <<<
+
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
         self.allreduce_embedding_grads(args)
@@ -498,6 +581,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def gather_model_params(self, args, timers):
+        raise Exception("hi.")
+
         timers('backward-params-all-gather').start()

         data_parallel_rank = mpu.get_data_parallel_rank()
@@ -526,69 +611,151 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-gather').stop()

     def _collect_main_grad_data_for_unscaling(self):
+        raise Exception("hi.")
         return [ g.data for g in self.get_main_grads() ]

-    def _copy_model_params_to_main_params(self):
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            main_param = self.get_main_param(group_index)
-            for model_param, main_shard in group_shard["param_map"].items():
-
-                # Model shard.
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["param"]
-                assert main_shard.size == model_shard.size
-
-                # Copy shard data.
-                main_view = main_param[main_shard.start:main_shard.end]
-                model_view = model_param.view(-1)[model_shard.start:model_shard.end]
-                main_view.detach().copy_(model_view)
+    # >>>
+    # def _copy_model_params_to_main_params(self):
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         main_param = self.get_main_param(group_index)
+    #         for model_param, main_range in group_range["param_map"].items():
+    #             # # Model range.
+    #             # # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # # model_range = self.model_gbuf_ranges \
+    #             # #     [model_index][dtype]["param_map"][model_param]["param"]
+    #             # model_range = self.get_model_param_range_map(model_param)["param"]
+    #             # assert main_range.size == model_range.size
+    #             # # Copy shard data.
+    #             # main_view = main_param[main_range.start:main_range.end]
+    #             # model_view = model_param.view(-1)[model_range.start:model_range.end]
+    #             # main_view.detach().copy_(model_view)
+    def _copy_model_params_to_main_params(self):
+        raise Exception("check if super's copy works.")
+    # <<<

-    def _copy_model_grads_to_main_grads(self):
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            for model_param, main_shard in group_shard["param_map"].items():
-
-                # Model shard.
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-                assert main_shard.size == model_shard.size
-
-                # Copy from DDP's contiguous buffer to main shard's grad.
-                model_grad = self.models[model_index]._grad_buffers[dtype].data
-                main_grad = self.get_main_grad(group_index)
-
-                # Copy sub-range within tensor.
-                model_view = model_grad[model_shard.start:model_shard.end]
-                main_view = main_grad[main_shard.start:main_shard.end]
-                main_view.detach().copy_(model_view)
+    # >>>
+    # def _copy_model_grads_to_main_grads(self):
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         for model_param, main_range in group_range["param_map"].items():
+    #             # # Model range.
+    #             # # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # # model_range = self.model_gbuf_ranges \
+    #             # #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
+    #             # model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
+    #             # assert main_range.size == model_range.size
+    #             # # Copy from DDP's contiguous buffer to main shard's grad.
+    #             # model_grad = self.models[model_index]._grad_buffers[dtype].data
+    #             # main_grad = self.get_main_grad(group_index)
+    #             # # Copy sub-range within tensor.
+    #             # model_view = model_grad[model_range.start:model_range.end]
+    #             # main_view = main_grad[main_range.start:main_range.end]
+    #             # main_view.detach().copy_(model_view)
+    # def _copy_model_grads_to_main_grads(self):
+    #     super()._copy_model_grads_to_main_grads()
+    #     raise Exception("check main param '.grad'.")
+    # for group in self.optimizer.param_groups:
+    #     for param in group["params"]:
+    #         param.grad =
+    def _copy_model_grads_to_main_grads(self):
+        # >>>
+        # print_seq([
+        #     "grad = %s." % tp(p.grad)
+        #     for g in self.optimizer.param_groups
+        #     for p in g["params"]
+        # ])
+        # <<<
+
+        # This only needs to be done for the float16 group.
+        for full_model_group, shard_main_group in zip(
+                self.full_float16_groups,
+                self.shard_fp32_from_float16_groups):
+            for full_model_param, shard_main_param in zip(full_model_group,
+                                                          shard_main_group):
+                param_range_map = self.get_model_param_range_map(full_model_param)
+                param_range = param_range_map["param"]
+                full_model_grad = full_model_param.main_grad
+                shard_model_grad = \
+                    full_model_grad[param_range.start:param_range.end]
+                shard_main_param.grad = shard_model_grad.float()
+                # >>>
+                if full_model_param.nelement() != shard_main_param.nelement():
+                    pax(0, {
+                        "param_range_map" : param_range_map,
+                        "param_range" : param_range,
+                        "full_model_param" : tp(full_model_param),
+                        "full_model_grad" : tp(full_model_grad),
+                        "shard_model_grad" : tp(shard_model_grad),
+                        "shard_main_grad" : tp(shard_main_param.grad),
+                        "shard_main_param" : tp(shard_main_param),
+                    })
+                # <<<
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        for group in self.fp32_groups:
+            for param in group:
+                param.grad = param.main_grad
+
+        # >>>
+        print_seq([
+            "grad = %s." % tp(p.grad)
+            for g in self.optimizer.param_groups
+            for p in g["params"]
+        ])
+        # <<<
+    # <<<

-    def _copy_main_params_to_model_params(self):
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            for model_param, main_shard in group_shard["param_map"].items():
-
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-                assert main_shard.size == model_shard.size
-
-                # Use DDP's contiguous buffer to temporarily hold params.
-                model_param = self.models[model_index]._grad_buffers[dtype].data
-                main_param = self.get_main_param(group_index)
-
-                # Copy sub-range within tensor.
-                model_view = model_param[model_shard.start:model_shard.end]
-                main_view = main_param[main_shard.start:main_shard.end]
-                model_view.detach().copy_(main_view)
+    # >>>
+    # def _copy_main_params_to_model_params(self):
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         for model_param, main_range in group_range["param_map"].items():
+    #             # # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # # model_range = self.model_gbuf_ranges \
+    #             # #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
+    #             # model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
+    #             # assert main_range.size == model_range.size
+    #             # # Use DDP's contiguous buffer to temporarily hold params.
+    #             # model_param = self.models[model_index]._grad_buffers[dtype].data
+    #             # main_param = self.get_main_param(group_index)
+    #             # # Copy sub-range within tensor.
+    #             # model_view = model_param[model_range.start:model_range.end]
+    #             # main_view = main_param[main_range.start:main_range.end]
+    #             # model_view.detach().copy_(main_view)
+    # def _copy_main_params_to_model_params(self):
+    #     super()._copy_main_params_to_model_params()
+    #     raise Exception("check main param '.grad'.")
+    def _copy_main_params_to_model_params(self):
+        raise Exception("hi.")
+
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
+            for model_param, main_param in zip(model_group, main_group):
+                model_param.main_grad.detach().copy_(main_param)
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        for group in self.fp32_groups:
+            for param in group:
+                param.main_grad.detach().copy_(param)
+    # <<<
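The new _copy_model_grads_to_main_grads above wires each fp32 shard main param's .grad to the slice of the model parameter's main_grad buffer that this rank owns. A reduced PyTorch illustration, not part of the commit; the shapes and the attached main_grad attribute are stand-ins.

    import torch

    model_param = torch.randn(30).half()     # stand-in fp16 model parameter
    model_param.main_grad = torch.randn(30)  # stand-in for its view into the DDP grad buffer
    start, end = 10, 30                       # this rank's sub-range of the parameter

    shard_main_param = model_param.detach()[start:end].clone().float()
    shard_model_grad = model_param.main_grad[start:end]
    shard_main_param.grad = shard_model_grad.float()   # optimizer steps on an fp32 grad

    assert shard_main_param.grad.nelement() == end - start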
megatron/optimizer/optimizer.py   (view file @ 2c3cb9fc)
@@ -323,6 +323,22 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<
+
     def _unscale_main_grads_and_check_for_nan(self):

         # Collect main grads.
@@ -552,18 +568,20 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-    def zero_grad(self, set_to_none=True):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups. We additionally zero
-        fp32_from_float16_groups as a memory optimization to reduce
-        fragmentation; in the case of set_to_none==True, the space
-        used by this field can be safely deallocated at this point."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<

     def _collect_main_grad_data_for_unscaling(self):