OpenDAS / Megatron-LM / Commits / 91f3579e

Commit 91f3579e, authored Mar 24, 2022 by Lawrence McAfee
Commit message: cleanup.
Parent: 4b9a218e

Changes: 1 changed file, with 21 additions and 172 deletions.

megatron/optimizer/distrib_optimizer.py  (+21, -172)
@@ -26,16 +26,12 @@ from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
# from .optimizer import Float16OptimizerWithFloat16Params
# >>>
from lutil import pax, tp, print_seq
# <<<


# >>>
# class Shard:
class Range:
# <<<

    def __init__(self, start, end):
        self.start = start
@@ -47,7 +43,6 @@ class Range:
        return "%d,%d [%d]" % (self.start, self.end, self.size)


# class DistributedOptimizer(Float16OptimizerWithFloat16Params):
class DistributedOptimizer(MixedPrecisionOptimizer):

    @classmethod
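The Range class above is just a record of a half-open [start, end) interval plus its size, with the "%d,%d [%d]" repr shown in this hunk. As a standalone illustration (not part of this commit; the completed fields, helper name, and even-split policy are assumptions), a range like this can describe the shard of a flat grad buffer owned by one data-parallel rank:

# Illustrative sketch, not part of the commit. Completes the Range helper
# under the assumption size = end - start (consistent with the repr above)
# and shows one hypothetical use: describing a rank's shard of a flat buffer.
class Range:

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


def shard_range_for_rank(buffer_numel, rank, world_size):
    """Hypothetical helper: split a flat buffer evenly across ranks."""
    shard_size = (buffer_numel + world_size - 1) // world_size
    start = min(rank * shard_size, buffer_numel)
    end = min(start + shard_size, buffer_numel)
    return Range(start, end)


print(shard_range_for_rank(buffer_numel=10, rank=1, world_size=4))  # -> 3,6 [3]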
@@ -82,6 +77,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        return param_range_map

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
@@ -121,6 +117,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        return data

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        return {
@@ -128,6 +125,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        '''Create a reverse of the model_gbuf_ranges, for referencing in
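The build_model_param_gbuf_map docstring describes a reverse lookup: given a parameter, find which (model chunk, dtype) grad buffer it lives in. A minimal sketch of that idea follows; the nested-dict shape of model_gbuf_ranges is inferred from the surrounding hunks, and the string stand-ins in the demo are hypothetical.

# Illustrative sketch, not the commit's code. Assumes model_gbuf_ranges is a
# list (one entry per model chunk) of {dtype: {"param_map": {param: ...}}}.
def build_param_to_gbuf_map(model_gbuf_ranges):
    param_gbuf_map = {}
    for model_index, gbuf_range_map in enumerate(model_gbuf_ranges):
        for dtype, gbuf_range in gbuf_range_map.items():
            for param in gbuf_range["param_map"]:
                # Each param lives in exactly one (model chunk, dtype) buffer.
                param_gbuf_map[param] = (model_index, dtype)
    return param_gbuf_map


# Tiny demo with strings standing in for parameter objects and dtypes.
ranges = [{"fp16": {"param_map": {"w1": {}, "w2": {}}}}]
print(build_param_to_gbuf_map(ranges))  # {'w1': (0, 'fp16'), 'w2': (0, 'fp16')}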
@@ -139,42 +137,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    # >>>
    # @classmethod
    # def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
    #     num_groups = len(param_groups)
    #     # Param group map.
    #     param_group_map = {}
    #     for group_index, group in enumerate(param_groups):
    #         for param in group["params"]:
    #             assert param.requires_grad
    #             param_group_map[param] = group_index
    #     # Optimizer group ranges.
    #     group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
    #     for model_gbuf_range_map in model_gbuf_ranges:
    #         for dtype, gbuf_range_map in model_gbuf_range_map.items():
    #             for param in gbuf_range_map["param_map"]:
    #                 group_index = param_group_map[param]
    #                 group_range = group_ranges[group_index]
    #                 param_size = gbuf_range_map["param_map"][param]["param"].size
    #                 param_group_start = group_range["size"]
    #                 param_group_end = param_group_start + param_size
    #                 param_group_range = Range(param_group_start, param_group_end)
    #                 group_range["size"] += param_size
    #                 group_range["param_map"][param] = param_group_range
    #     # Squeeze zero-size group ranges.
    #     for group_index, group_range in enumerate(group_ranges):
    #         group_range["orig_group"] = param_groups[group_index]
    #     group_ranges = [ g for g in group_ranges if g["size"] > 0 ]
    #     return group_ranges
    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
@@ -291,6 +254,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
@@ -302,11 +266,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Verify that contiguous buffers are being used
        # - Note: this should already be checked in arguments.py
        # >>>
        # args = get_args()
        # assert args.use_contiguous_buffers_in_local_ddp
        assert use_contiguous_buffers_in_local_ddp
        # <<<

        # Model grad buffer ranges.
        self.model_gbuf_ranges = []
@@ -331,12 +291,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            self.model_param_gbuf_map,
            self.opt_group_ranges)
        # print_seq("16 [%d], 16x32 [%d], 32 [%d]." % (
        #     sum(len(g) for g in self.float16_groups),
        #     sum(len(g) for g in self.fp32_from_float16_groups),
        #     sum(len(g) for g in self.fp32_groups),
        # ))

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
@@ -344,34 +298,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        # >>>
        # # Initialize main params.
        # self._copy_model_params_to_main_params()
        # <<<

        # >>>
        # # Params for grad norm.
        # self.main_grad_views_for_grad_norm = self.build_main_grad_views_for_grad_norm(
        #     self.opt_group_ranges,
        #     self.optimizer)
        # <<<

    def get_model_param_range_map(self, param):
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        # >>>
        # pax(0, {
        #     "param" : param,
        #     "model_index" : model_index,
        #     "dtype" : str(dtype),
        #     "gbuf_range_map" : gbuf_range_map,
        #     "param_range_map" : param_range_map,
        # })
        # <<<
        return param_range_map
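The comment about leveraging state_dict() / load_state_dict() to recast per-param state, together with the self.optimizer.load_state_dict(self.optimizer.state_dict()) call above, relies on a small torch.optim trick. A self-contained sketch of that trick is below (illustrative only; the explicit re-keying of optimizer.state is an assumption made so the example runs on its own):

import torch

# Minimal sketch (not the commit's code) of the "recast via state_dict
# round-trip" pattern: per-param optimizer state is re-keyed onto replacement
# parameter objects, and reloading the optimizer's own state_dict then casts
# that state to the replacements' dtype.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
opt = torch.optim.Adam(params, lr=1e-3)

# Create some per-param state (exp_avg / exp_avg_sq).
for p in params:
    p.grad = torch.randn_like(p)
opt.step()

# Swap in replacement parameter objects of a different dtype.
new_params = [torch.nn.Parameter(p.detach().half()) for p in params]
new_group = dict(opt.param_groups[0])        # keep lr, betas, eps, ...
new_group["params"] = new_params
opt.param_groups = [new_group]
for old_p, new_p in zip(params, new_params):
    opt.state[new_p] = opt.state.pop(old_p)  # re-key state onto new params

# Round-trip: state tensors now follow the new params' dtype.
opt.load_state_dict(opt.state_dict())
assert opt.state[new_params[0]]["exp_avg"].dtype == torch.float16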
@@ -379,28 +310,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        return None

    # def get_main_params(self):
    #     return [ g["params"][0] for g in self.optimizer.param_groups ]
    # def get_main_grads(self):
    #     return [ p.grad for p in self.get_main_params() ]
    # def get_main_param(self, group_index):
    #     return self.get_main_params()[group_index]
    # def get_main_grad(self, group_index):
    #     return self.get_main_param(group_index).grad
    # >>>
    # def get_main_grads_for_grad_norm(self):
    #     return self.main_grad_views_for_grad_norm
    # def get_main_grads_for_grad_norm(self):
    #     raise Exception("....... use 'super' .......")
    #     grads_for_norm = super().get_main_grads_for_grad_norm()
    #     if torch.distributed.get_rank() == 1:
    #         print_seq([ tp(g) for g in grads_for_norm ])
    #     return grads_for_norm
    # <<<

    # def state_dict(self):
    #     state_dict = {}
    #     state_dict['optimizer'] = self.optimizer.state_dict()
@@ -410,8 +320,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #     return state_dict
    def state_dict(self):
        raise Exception("fix me.")
    # <<<

    # >>>
    # def load_state_dict(self, state_dict):
    #     # Optimizer.
    #     optimizer_key = 'optimizer'
@@ -441,20 +353,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #             current_param.data.copy_(saved_param.data)
    def load_state_dict(self, state_dict):
        raise Exception("hi.")
    # <<<

    # >>>
    # def zero_grad(self, set_to_none=True):
    #     # Collect model params.
    #     model_params = []
    #     for model in self.models:
    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
    #             model_params.extend(param_map.keys())
    #     # Distributed optimizer requires contiguous buffer; don't set to None.
    #     _zero_grad_group_helper(model_params, set_to_none = False)
    # def zero_grad(self, set_to_none=True):
    #     raise Exception("does 'super' work?")
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_groups. We additionally zero
@@ -469,7 +370,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                       self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
        # <<<

    def get_model_grad_buffer_dp_views(self):
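_zero_grad_group_helper is imported from .optimizer at the top of the file and is called once per parameter group above. A sketch of the conventional zero-grad pattern such a helper typically implements is shown below; it may not match the real helper exactly.

# Sketch of a conventional per-group zero-grad helper; may differ from
# megatron.optimizer's actual _zero_grad_group_helper.
def zero_grad_group(params, set_to_none):
    for param in params:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            # Keep the (possibly contiguous) grad storage; only clear values.
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()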
@@ -489,6 +389,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
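The docstring of reduce_model_grads is cut off by the diff context, but the timer name 'backward-params-all-reduce' in the next hunk indicates an all-reduce of the flat grad-buffer views over the data-parallel group. A minimal sketch of that step follows; names such as gbuf_views and data_parallel_group are assumptions.

import torch

# Illustrative sketch, not the commit's code: average each flat grad-buffer
# view across the data-parallel group.
def all_reduce_grad_buffer_views(gbuf_views, data_parallel_group):
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    for gbuf_view in gbuf_views:
        gbuf_view.div_(world_size)  # pre-divide so the summed result is a mean
        torch.distributed.all_reduce(gbuf_view, group=data_parallel_group)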
@@ -522,6 +423,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
        timers('backward-params-all-gather').start()
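gather_model_params, timed as 'backward-params-all-gather' above, is the counterpart step in which each data-parallel rank contributes its updated parameter shard and the full parameter buffers are reassembled everywhere. A hedged sketch of that collective; the buffer names and equal-shard requirement are assumptions.

import torch

# Illustrative sketch, not the commit's code. Requires
# full_param_buf.numel() == world_size * shard_param_buf.numel().
def all_gather_updated_params(full_param_buf, shard_param_buf,
                              data_parallel_group):
    torch.distributed.all_gather_into_tensor(
        full_param_buf, shard_param_buf, group=data_parallel_group)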
@@ -552,55 +454,27 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('backward-params-all-gather').stop()

    # >>>
    # def _collect_main_grad_data_for_unscaling(self):
    #     return [ g.data for g in self.get_main_grads() ]
    def _collect_main_grad_data_for_unscaling(self):
        main_grad_data = [
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
        # print_seq([ tp(g) for g in main_grad_data ])
        return main_grad_data
    # <<<

    # >>>
    # def _copy_model_params_to_main_params(self):
    #     for group_index, group_range in enumerate(self.opt_group_ranges):
    #         main_param = self.get_main_param(group_index)
    #         for model_param, main_range in group_range["param_map"].items():
    #             # Model range.
    #             # model_index, dtype = self.param_gbuf_map[model_param]
    #             # model_range = self.model_gbuf_ranges \
    #             #     [model_index][dtype]["param_map"][model_param]["param"]
    #             model_range = self.get_model_param_range_map(model_param)["param"]
    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
    #             assert main_range.size == model_range.size
    #             # Copy shard data.
    #             main_view = main_param[main_range.start:main_range.end]
    #             model_view = model_param.view(-1)[model_range.start:model_range.end]
    #             main_view.detach().copy_(model_view)
    def _copy_model_params_to_main_params(self):
        raise Exception("check if super's copy works.")
    # <<<

    # >>>
    def _copy_model_grads_to_main_grads(self):
        # >>>
        # print_seq([
        #     "grad = %s." % tp(p.grad)
        #     for g in self.optimizer.param_groups
        #     for p in g["params"]
        # ])
        # <<<
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(
                    full_model_groups, shard_main_groups):
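_collect_main_grad_data_for_unscaling above now returns the .grad.data tensor of every param in the wrapped optimizer's groups. For context, a mixed-precision wrapper typically consumes such a list by multiplying each grad by the inverse loss scale before clipping and stepping; a trivial sketch (the caller and inv_scale are assumptions):

# Illustrative sketch of how the collected grad tensors are typically used;
# not part of the commit.
def unscale_main_grads(main_grad_data, inv_scale):
    for grad in main_grad_data:
        grad.mul_(inv_scale)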
@@ -616,28 +490,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    [param_range.start:param_range.end]
                shard_main_param.grad = shard_model_grad.float()

        # print_seq([ "%s / %d, [%d] %s" % (
        #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
        # ) for k, gs in [
        #     ("model", self.full_float16_groups),
        #     ("main", self.shard_fp32_from_float16_groups),
        # ] for i, g in enumerate(gs)])
        copy_group_grads(self.full_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.full_fp32_groups, self.shard_fp32_groups)
        # >>>
        # print_seq([
        #     "grad = %s." % tp(p.grad)
        #     for g in self.optimizer.param_groups
        #     for p in g["params"]
        # ])
        # <<<
    # <<<

    # >>>
    def _copy_main_params_to_model_params(self):

        def copy_group_params(shard_main_groups, full_model_groups):
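The copy_group_grads closure above takes each parameter's flattened model-space gradient, slices out this rank's Range, and installs the fp32 cast as the shard main param's grad. An isolated, hypothetical example of that slice-and-cast step:

import torch

# Hypothetical standalone illustration (not the commit's code) of slicing a
# flattened fp16 gradient by a param range and exposing it as the fp32 grad
# of the corresponding shard main param.
model_grad = torch.randn(16, dtype=torch.float16)      # flattened model grad
shard_main_param = torch.nn.Parameter(torch.zeros(4))  # this rank's fp32 shard
start, end = 4, 8                                      # this param's Range on this rank
shard_main_param.grad = model_grad[start:end].float()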
@@ -660,12 +518,3 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                          self.full_float16_groups)
        copy_group_params(self.shard_fp32_groups, self.full_fp32_groups)
        # >>>
        # print_seq([
        #     "grad = %s." % tp(p.grad)
        #     for g in self.optimizer.param_groups
        #     for p in g["params"]
        # ])
        # <<<
    # <<<