Commit 0481f58e authored by Lawrence McAfee

grad norm debugging.

parent 3d433e8a
@@ -68,6 +68,24 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)
        # >>>
        else:
            # Debug: this parameter was excluded from the norm (no grad, shared,
            # or a tensor-parallel duplicate); print the state rank by rank, then exit.
            # from lutil import pax
            # pax({"grad": grad})
            from megatron import get_args
            args = get_args()
            for r in range(torch.distributed.get_world_size()):
                if torch.distributed.get_rank() == r:
                    print("collect: r %d, dist-op %d, np %d, ne %d, g %s" % (
                        torch.distributed.get_rank(),
                        args.use_distributed_optimizer,
                        len(parameters),
                        sum(t.nelement() for t in parameters),
                        str(tuple(grad.shape)),
                    ))
                torch.distributed.barrier()
            exit(0)
        # <<<

    # Norm parameters.
    max_norm = float(max_norm)
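
A note on the print pattern used in the debug block above (and repeated in the other hunks): every rank walks the same loop over ranks, only the rank whose turn it is prints, and the barrier at the end of each iteration keeps the output serialized. A minimal standalone sketch of that pattern, assuming an already-initialized `torch.distributed` process group (the `debug_print_by_rank` name is illustrative, not part of this commit):

```python
import torch

def debug_print_by_rank(message):
    # Each rank takes a turn: only rank r prints on iteration r, and the
    # barrier keeps all ranks in lockstep so lines come out in rank order.
    rank = torch.distributed.get_rank()
    for r in range(torch.distributed.get_world_size()):
        if rank == r:
            print("rank %d: %s" % (rank, message), flush=True)
        torch.distributed.barrier()
```
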
@@ -100,6 +118,30 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type
            # >>>
            # Debug: after the fused L2-norm, print each rank's local grad norm and
            # the element counts of parameters / grads / grads_for_norm, then exit.
            from megatron import get_args
            from lutil import pax
            args = get_args()
            for r in range(torch.distributed.get_world_size()):
                if torch.distributed.get_rank() == r:
                    print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
                        torch.distributed.get_rank(),
                        args.use_distributed_optimizer,
                        grad_norm.item(),
                        sum(t.nelement() for t in parameters),
                        sum(t.nelement() for t in grads),
                        sum(t.nelement() for t in grads_for_norm),
                    ))
                torch.distributed.barrier()
            exit(0)
            # pax(2, {
            #     "use distrib opt" : args.use_distributed_optimizer,
            #     "norm_type" : norm_type,
            #     "grad_norm" : grad_norm.item(),
            #     "total_norm" : total_norm.item(),
            # })
            # <<<

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
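
For context on the hunk above: each model-parallel rank only computes the norm of its own shard of `grads_for_norm`, so the code keeps the local norm raised to the `norm_type` power; those powered values can then be summed across ranks (the all-reduce in the next hunk) and re-rooted to recover the norm of the full gradient. A small single-process sketch of that identity, with made-up shard values not taken from the commit:

```python
import torch

norm_type = 2.0
# Pretend these are the grads_for_norm lists held by two model-parallel ranks.
rank0_grads = [torch.tensor([3.0, 4.0])]
rank1_grads = [torch.tensor([12.0])]

def local_norm_pow(grads, p):
    # Each rank's contribution: (local p-norm) ** p, i.e. the sum of |g| ** p.
    return sum(torch.norm(g, p) ** p for g in grads)

# The all-reduce sums the per-rank contributions...
summed = local_norm_pow(rank0_grads, norm_type) + local_norm_pow(rank1_grads, norm_type)
# ...and the p-th root recovers the full norm: sqrt(9 + 16 + 144) = 13.
total_norm = summed.item() ** (1.0 / norm_type)
full = torch.cat([g.flatten() for g in rank0_grads + rank1_grads])
assert abs(total_norm - torch.norm(full, norm_type).item()) < 1e-6
```
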
@@ -111,6 +153,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
                                     group=model_parallel_group)
        total_norm = total_norm.item() ** (1.0 / norm_type)
        # >>>
        # Debug: dump the reduced global norm.
        from megatron import get_args
        from lutil import pax
        args = get_args()
        pax(0, {
            "use distrib opt" : args.use_distributed_optimizer,
            "norm_type" : norm_type,
            "total_norm" : total_norm,
        })
        # <<<

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
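
The `# Scale.` lines at the end of this hunk implement the usual clip-by-global-norm rule: if the global norm exceeds `max_norm`, every gradient is scaled by roughly `max_norm / total_norm`; otherwise it is left untouched. A plain-PyTorch sketch of that rule, assuming a list of gradient tensors and the already-reduced float `total_norm` from above (the in-place Python loop here is only for illustration):

```python
import torch

def clip_by_global_norm_(grads, max_norm, total_norm, eps=1.0e-6):
    # Same rule as the "# Scale." step above: shrink all grads by a common
    # factor only when the global norm is over the limit.
    clip_coeff = max_norm / (total_norm + eps)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm
```

With `max_norm=1.0` and `total_norm=13.0`, for instance, every gradient element ends up multiplied by roughly 1/13.
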
@@ -408,7 +408,91 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            # >>>
            # Debug: count how many parameter elements would pass / fail the
            # grad-norm filter on each rank, print one rank at a time, then exit.
            from megatron.model.module import param_is_not_shared
            from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

            def use_grad(p):
                # A param contributes to the grad norm only if all of these hold.
                conditions = [
                    p.grad is not None,
                    param_is_not_shared(p),
                    param_is_not_tensor_parallel_duplicate(p),
                    # getattr(p, "shared", False),
                ]
                return all(conditions)

            # def print_module(m, d):
            #     ps = [ "%d/%s" % (
            #         use_grad(p),
            #         str(tuple(p.shape)),
            #     ) for p in m.parameters(recurse = False) ]
            #     ps = [
            #         str(tuple(p))
            #         for p in m.parameters(recurse = False)
            #         if use_grad(p)
            #     ]
            #     print("%s %s | %s" % (".." * d, type(m).__name__, ", ".join(ps)))
            # if torch.distributed.get_rank() == 0:
            #     visited = []
            #     queue = [ (m, 0) for m in self.models ]
            #     while queue:
            #         m, d = queue.pop()
            #         visited.append((m, d))
            #         # print_module(m, d)
            #         queue.extend(reversed([ (mm, d + 1) for mm in m.children() ]))
            #     for m, d in visited:
            #         print_module(m, d)

            for r in range(torch.distributed.get_world_size()):
                if r == torch.distributed.get_rank():
                    # print("r %d, %s" % (
                    #     torch.distributed.get_rank(),
                    #     "".join(
                    #         "%d" % use_grad(p)
                    #         for m in self.models
                    #         for p in m.parameters()
                    #     ),
                    # ))
                    # print("r %d [ d %d, t %d, p %d ] ... %s" % (
                    #     torch.distributed.get_rank(),
                    #     mpu.get_data_parallel_rank(),
                    #     mpu.get_tensor_model_parallel_rank(),
                    #     mpu.get_pipeline_model_parallel_rank(),
                    #     ", ".join(str(tuple(p.shape)) for p in self.get_parameters() if not use_grad(p)),
                    # ))
                    print("r %d [ d %d, t %d, p %d ] ... %d, %d ... %s" % (
                        torch.distributed.get_rank(),
                        mpu.get_data_parallel_rank(),
                        mpu.get_tensor_model_parallel_rank(),
                        mpu.get_pipeline_model_parallel_rank(),
                        sum(p.nelement()
                            for p in self.get_parameters()
                            if use_grad(p)),
                        sum(p.nelement()
                            for p in self.get_parameters()
                            if not use_grad(p)),
                        "".join(
                            "%d" % use_grad(p)
                            for p in self.get_parameters()
                        ),
                    ))
                torch.distributed.barrier()
            torch.distributed.barrier()
            exit(0)
            # <<<
            grad_norm = self.clip_grad_norm(self.clip_grad)
            # >>>
            # Debug: dump the grad norm returned by clip_grad_norm.
            from lutil import pax
            pax(0, {
                "use distrib opt" : args.use_distributed_optimizer,
                "grad_norm" : grad_norm,
            })
            # <<<
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
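
The `use_grad` helper in the hunk above applies the same three conditions that `clip_grad_norm_fp32` uses when building `grads_for_norm`: the grad exists, the parameter is not shared, and it is not a tensor-parallel duplicate. A sketch of the element counting that the debug print performs, factored into a reusable helper (`count_norm_elements` is an illustrative name, not part of the commit):

```python
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

def count_norm_elements(parameters):
    # Split the parameters into elements that would be counted in the grad
    # norm versus elements that the filter skips.
    counted, skipped = 0, 0
    for p in parameters:
        eligible = (p.grad is not None
                    and param_is_not_shared(p)
                    and param_is_not_tensor_parallel_duplicate(p))
        if eligible:
            counted += p.nelement()
        else:
            skipped += p.nelement()
    return counted, skipped
```

Calling this on `self.get_parameters()` would reproduce the two element totals shown in the per-rank print above.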