Commit b1781f31 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Correctly copy tensor-parallel attributes for fp32_groups.

parent d0925652
......@@ -99,6 +99,11 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
group=model_parallel_group)
total_norm = total_norm.item() ** (1.0 / norm_type)
# >>>
from lutil import pax, tp, print_seq
print_seq("norm : grad %s, total %s." % (grad_norm.item(), total_norm))
# <<<
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
......
......@@ -235,6 +235,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
[param_range.start:param_range.end]
full_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
mpu.copy_tensor_model_parallel_attributes(
shard_model_param, model_param)
else:
raise TypeError('Wrapped parameters must be one of '
......
......@@ -122,6 +122,14 @@ class MegatronOptimizer(ABC):
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# >>>
# from lutil import pax
# pax(0, {
# "params" : params,
# "grads_for_norm" : grads_for_norm,
# })
# <<<
return grads_for_norm
......@@ -133,6 +141,16 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
# >>>
from lutil import print_seq
# print_seq("params %d, ngrads %d." % (len(params), len(grads_for_norm)))
# print_seq([
# "grads_for_norm / %d = %s." % (i, str(tuple(g.shape)))
# for i, g in enumerate(grads_for_norm)
# ])
print_seq("grads_for_norm = %s." % ", ".join(
str(tuple(g.shape)) for g in grads_for_norm))
# <<<
return clip_grad_norm_fp32(
params, grads_for_norm, clip_grad,
model_parallel_group=self.get_model_parallel_group())
......@@ -295,7 +313,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert self.bf16, 'fp16 expects a grad scaler.'
assert not self.fp16, 'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/inf has happened.
# Any non-zero value indicates inf/nan.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment