Commit 5e29905f authored by Lawrence McAfee

stalling in copy_grads().

parent 2c3cb9fc
@@ -447,8 +447,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    # >>>
    # def get_main_grads_for_grad_norm(self):
    #     return self.main_grad_views_for_grad_norm
    def get_main_grads_for_grad_norm(self):
        # Debug trap: halt to check whether this override is actually hit.
        raise Exception("does 'super' work?")
    # def get_main_grads_for_grad_norm(self):
    #     raise Exception("....... use 'super' .......")
    #     grads_for_norm = super().get_main_grads_for_grad_norm()
    #     if torch.distributed.get_rank() == 1:
    #         print_seq([ tp(g) for g in grads_for_norm ])
    #     return grads_for_norm
    # <<<
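    # For reference, a minimal sketch (not from this commit) of answering the
    # "does 'super' work?" question by delegating to the parent class; this
    # assumes MixedPrecisionOptimizer defines get_main_grads_for_grad_norm():
    #
    # def get_main_grads_for_grad_norm(self):
    #     # super() resolves along the MRO, so the base implementation is
    #     # reachable from this subclass method.
    #     return super().get_main_grads_for_grad_norm()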
@@ -493,6 +497,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def load_state_dict(self, state_dict):
        # Debug trap: halt to confirm checkpoint loading reaches this override.
        raise Exception("hi.")

    # >>>
    # def zero_grad(self, set_to_none=True):
    #     # Collect model params.
@@ -505,7 +510,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #     _zero_grad_group_helper(model_params, set_to_none = False)
    # def zero_grad(self, set_to_none=True):
    #     raise Exception("does 'super' work?")
    # <<<
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_groups. We additionally zero
@@ -515,6 +519,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for groups in (
                self.full_float16_groups,
                self.full_fp32_groups,
                self.shard_float16_groups,  # grad empty/unused here?
                self.shard_fp32_groups,
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
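    # The helper used above is module-level in optimizer.py; its body is cut
    # off by a later hunk, but the stock version looks roughly like this
    # (sketch, not part of this commit):
    #
    # def _zero_grad_group_helper(group, set_to_none):
    #     for param in group:
    #         if param.grad is not None:
    #             if set_to_none:
    #                 param.grad = None  # frees the grad storage immediately
    #             else:
    #                 if param.grad.grad_fn is not None:
    #                     param.grad.detach_()
    #                 else:
    #                     param.grad.requires_grad_(False)
    #                 param.grad.zero_()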
@@ -550,6 +556,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        #     for m in self.models
        #     for b in m._grad_buffers.values()
        # ])
        # print_seq("hi.")
        # <<<

        # All-reduce embedding grads.
@@ -577,6 +584,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                group = data_parallel_group,
            )
        # >>>
        # print_seq("hi.")
        # <<<
        timers('backward-params-all-reduce').stop()
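        # The all-reduce above runs over each model chunk's contiguous grad
        # buffer. A minimal sketch of that pattern (not from this commit),
        # assuming each chunk exposes ._grad_buffers as a dtype -> buffer map:
        #
        # for model in self.models:
        #     for gbuf in model._grad_buffers.values():
        #         gbuf.data /= torch.distributed.get_world_size(
        #             group=data_parallel_group)
        #         torch.distributed.all_reduce(
        #             gbuf.data, group=data_parallel_group)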
    def gather_model_params(self, args, timers):
@@ -610,9 +621,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('backward-params-all-gather').stop()

    # >>>
    # def _collect_main_grad_data_for_unscaling(self):
    #     return [ g.data for g in self.get_main_grads() ]
    def _collect_main_grad_data_for_unscaling(self):
        # Debug trap: halt before unscaling to confirm this path is reached.
        raise Exception("hi.")
        # Dead code while the trap above is in place. The old one-liner is
        # retained, commented, for comparison with the collection below:
        # return [ g.data for g in self.get_main_grads() ]
        main_grad_data = [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
        # print_seq([ tp(g) for g in main_grad_data ])
        return main_grad_data
    # <<<
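    # Downstream, the collected grads feed the unscale/check step; roughly the
    # stock MixedPrecisionOptimizer flow (sketch, not part of this commit):
    #
    # def _unscale_main_grads_and_check_for_nan(self):
    #     main_grads = self._collect_main_grad_data_for_unscaling()
    #     self.found_inf.fill_(0.0)
    #     # Unscale in place and flag any inf/nan across all main grads.
    #     torch._amp_foreach_non_finite_check_and_unscale_(
    #         main_grads, self.found_inf, self.grad_scaler.inv_scale)
    #     # Reduce the flag across the model-parallel group.
    #     torch.distributed.all_reduce(
    #         self.found_inf,
    #         op=torch.distributed.ReduceOp.MAX,
    #         group=self.get_model_parallel_group())
    #     return self.found_inf.item() > 0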
    # >>>
    # def _copy_model_params_to_main_params(self):
@@ -678,44 +700,55 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        #     ])
        # <<<
        # This only needs to be done for the float16 group.
        for full_model_group, shard_main_group in zip(
                self.full_float16_groups,
                self.shard_fp32_from_float16_groups):
            for full_model_param, shard_main_param in zip(full_model_group,
                                                          shard_main_group):
                param_range_map = self.get_model_param_range_map(full_model_param)
                param_range = param_range_map["param"]
                full_model_grad = full_model_param.main_grad
                shard_model_grad = \
                    full_model_grad[param_range.start:param_range.end]
                shard_main_param.grad = shard_model_grad.float()
                # >>>
                if full_model_param.nelement() != shard_main_param.nelement():
                    pax(0, {
                        "param_range_map" : param_range_map,
                        "param_range" : param_range,
                        "full_model_param" : tp(full_model_param),
                        "full_model_grad" : tp(full_model_grad),
                        "shard_model_grad" : tp(shard_model_grad),
                        "shard_main_grad" : tp(shard_main_param.grad),
                        "shard_main_param" : tp(shard_main_param),
                    })
                # <<<

        # For fp32 grads, we need to reset the grads to main grad.
        for group in self.fp32_groups:
            for param in group:
                param.grad = param.main_grad
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(full_model_groups,
                                                          shard_main_groups):
                for full_model_param, shard_main_param in zip(full_model_group,
                                                              shard_main_group):
                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = \
                        full_model_grad[param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()
                    # >>>
                    # Debug dump: a shard has fewer elements than its full
                    # param whenever the param is split across ranks, so this
                    # fires on any split param.
                    if full_model_param.nelement() != shard_main_param.nelement():
                        pax(0, {
                            "param_range_map" : param_range_map,
                            "param_range" : param_range,
                            "full_model_param" : tp(full_model_param),
                            "full_model_grad" : tp(full_model_grad),
                            "shard_model_grad" : tp(shard_model_grad),
                            "shard_main_grad" : tp(shard_main_param.grad),
                            "shard_main_param" : tp(shard_main_param),
                        })
                    # <<<
# print_seq("float16 groups: %d [%s], %d [%s]." % (
# len(self.full_float16_groups),
# # ",".join(str(len(g)) for g in self.full_float16_groups),
# ",".join(str(tuple(p.shape)) for gs in self.full_float16_groups for g in gs for p in g),
# len(self.shard_fp32_from_float16_groups),
# ",".join(str(len(g)) for g in self.shard_fp32_from_float16_groups),
# ))
gs = self.full_float16_groups
pax(0, {
**{"gs / %d" % i : len(g) for i, g in enumerate(gs)},
})
copy_group_grads(self.full_float16_groups,
self.shard_fp32_from_float16_groups)
print_seq("hi.")
copy_group_grads(self.full_fp32_groups,
self.shard_fp32_groups)
        # >>>
        print_seq([
            "grad = %s." % tp(p.grad)
            for g in self.optimizer.param_groups
            for p in g["params"]
        ])
        # print_seq([
        #     "grad = %s." % tp(p.grad)
        #     for g in self.optimizer.param_groups
        #     for p in g["params"]
        # ])
        # <<<
    # <<<
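    # A standalone sketch of the slice-and-cast step inside copy_group_grads
    # (editorial; all names below are invented for the example):
    #
    # import torch
    # full_model_grad = torch.arange(10, dtype=torch.float16)  # full grad buffer
    # start, end = 4, 8                                        # this rank's param_range
    # shard_main_param = torch.nn.Parameter(torch.zeros(end - start))
    # shard_main_param.grad = full_model_grad[start:end].float()  # fp32 shard grad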
@@ -33,6 +33,10 @@ from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from lutil import pax, tp, print_seq
# <<<
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
@@ -427,6 +431,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()
        # Debug trap: halt right after the grad copy to bisect the stall.
        print_seq("hi.")

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:
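        # The branch below is cut off by this hunk; in the stock Megatron flow
        # it continues roughly as follows (sketch, not part of this commit):
        #
        #     timers('optimizer-unscale-and-check-inf').start()
        #     found_inf_flag = self._unscale_main_grads_and_check_for_nan()
        #     timers('optimizer-unscale-and-check-inf').stop()
        #     # Update the loss scale and skip the step on overflow.
        #     self.grad_scaler.update(found_inf_flag)
        #     if found_inf_flag:
        #         return False, None, None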
@@ -569,18 +574,18 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    # >>>
    # def zero_grad(self, set_to_none=True):
    #     """We only need to zero the model related parameters, i.e.,
    #     float16_groups & fp32_from_fp32_groups. We additionally zero
    #     fp32_from_float16_groups as a memory optimization to reduce
    #     fragmentation; in the case of set_to_none==True, the space
    #     used by this field can be safely deallocated at this point."""
    #     for group in self.float16_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    #     for group in self.fp32_from_float16_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    #     for group in self.fp32_from_fp32_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)
    # <<<
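    # Typical call site (sketch, assuming the standard Megatron training loop;
    # not part of this commit):
    #
    # optimizer.zero_grad(set_to_none=True)   # release grad storage up front
    # # ... forward and backward passes repopulate the grads ...
    # optimizer.step()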