Commit c6d20c05 authored by Lawrence McAfee

implemented 'get_grad_views_for_grad_norm()'.

parent 0481f58e
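Note: the centerpiece of this commit is the get_grad_views_for_grad_norm() classmethod added to DistributedOptimizer further below. It walks each optimizer param group's flattened gradient and slices out one view per parameter shard, keeping only parameters that are neither shared nor tensor-parallel duplicates, so that the grad-norm computation can be fed contiguous slices of the flat gradient. A minimal standalone sketch of that slicing pattern follows; the Shard/param_map names mirror the diff, while the tensors and the contributes_to_norm predicate are illustrative stand-ins rather than Megatron's actual objects (which use param_is_not_shared and param_is_not_tensor_parallel_duplicate).

import torch
from collections import namedtuple

# Mirrors the Shard(start, end) class from the diff (a namedtuple here for brevity).
Shard = namedtuple("Shard", ["start", "end"])

def contributes_to_norm(param):
    # Stand-in for Megatron's param_is_not_shared(p) and
    # param_is_not_tensor_parallel_duplicate(p) checks.
    return True

# Hypothetical flattened gradient for one optimizer param group, and a
# hypothetical param -> shard map (the keys would be model params in Megatron).
flat_grad = torch.arange(10, dtype=torch.float32)
param_map = {"w1": Shard(0, 4), "w2": Shard(4, 10)}

grad_views = [
    flat_grad[shard.start:shard.end]
    for param, shard in param_map.items()
    if contributes_to_norm(param)
]
print([tuple(v.shape) for v in grad_views])  # [(4,), (6,)]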
@@ -119,21 +119,21 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
total_norm = grad_norm ** norm_type
# >>>
from megatron import get_args
from lutil import pax
args = get_args()
for r in range(torch.distributed.get_world_size()):
if torch.distributed.get_rank() == r:
print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
torch.distributed.get_rank(),
args.use_distributed_optimizer,
grad_norm.item(),
sum(t.nelement() for t in parameters),
sum(t.nelement() for t in grads),
sum(t.nelement() for t in grads_for_norm),
))
torch.distributed.barrier()
exit(0)
# from megatron import get_args
# from lutil import pax
# args = get_args()
# for r in range(torch.distributed.get_world_size()):
# if torch.distributed.get_rank() == r:
# print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
# torch.distributed.get_rank(),
# args.use_distributed_optimizer,
# grad_norm.item(),
# sum(t.nelement() for t in parameters),
# sum(t.nelement() for t in grads),
# sum(t.nelement() for t in grads_for_norm),
# ))
# torch.distributed.barrier()
# exit(0)
# pax(2, {
# "use distrib opt" : args.use_distributed_optimizer,
# "norm_type" : norm_type,
@@ -154,14 +154,14 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
total_norm = total_norm.item() ** (1.0 / norm_type)
# >>>
from megatron import get_args
from lutil import pax
args = get_args()
pax(0, {
"use distrib opt" : args.use_distributed_optimizer,
"norm_type" : norm_type,
"total_norm" : total_norm,
})
# from megatron import get_args
# from lutil import pax
# args = get_args()
# pax(0, {
# "use distrib opt" : args.use_distributed_optimizer,
# "norm_type" : norm_type,
# "total_norm" : total_norm,
# })
# <<<
# Scale.
@@ -22,9 +22,17 @@ import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# <<<
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
# >>>
from .optimizer import get_clippy
from lutil import pax, tp
# <<<
class Shard:
def __init__(self, start, end):
@@ -188,6 +196,45 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Update group's param.
group_shard["orig_group"]["params"] = [ main_param ]
# >>>
@classmethod
def get_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
grad_views = []
# grad_views_SKIPPED = []
for group_index, opt_group_shard in enumerate(opt_group_shards):
opt_grad = optimizer.param_groups[group_index]["params"][0].grad
for param, shard in opt_group_shard["param_map"].items():
if param_is_not_shared(param) and \
param_is_not_tensor_parallel_duplicate(param):
grad_view = opt_grad[shard.start:shard.end]
grad_views.append(grad_view)
# else:
# grad_views_SKIPPED.append(opt_grad[shard.start:shard.end])
# >>>
# my_rank = torch.distributed.get_rank()
# for r in range(torch.distributed.get_world_size()):
# if r == my_rank:
# print("r %d, grad views %s." % (
# my_rank,
# ", ".join(str(tuple(g.shape)) for g in grad_views),
# ))
# torch.distributed.barrier()
# for r in range(torch.distributed.get_world_size()):
# if r == my_rank:
# print("r %d, SKIPPED %s." % (
# my_rank,
# ", ".join(str(tuple(g.shape)) for g in grad_views_SKIPPED),
# ))
# torch.distributed.barrier()
# exit(0)
# <<<
return grad_views
# <<<
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
@@ -227,6 +274,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Initialize main params.
self._copy_model_params_to_main_params()
# >>> numel/nelem per rank >>>
# for r in range(torch.distributed.get_world_size()):
# if r == torch.distributed.get_rank():
# for m in self.models:
# for b in m._grad_buffers.values():
# print("r %d, %d." % (r, b.data.nelement()))
# torch.distributed.barrier()
# exit(0)
# <<<
# Params for grad norm.
self.grad_views_for_grad_norm = self.get_grad_views_for_grad_norm(
self.opt_group_shards,
self.optimizer)
def get_model_parallel_group(self):
return None
@@ -407,6 +470,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
def gather_model_params(self, args, timers):
timers('backward-params-all-gather').start()
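In this commit the collected views are only stored on self.grad_views_for_grad_norm and not yet consumed; presumably they are intended to feed the grads_for_norm path in clip_grad_norm_fp32 for the distributed optimizer. As a generic, plain-PyTorch reminder (not Megatron's actual fused implementation), the local 2-norm over a list of 1-D gradient views reduces to a norm of per-view norms:

import torch

def local_l2_norm(grad_views):
    # Norm-of-norms, the same pattern torch.nn.utils.clip_grad_norm_ uses.
    if not grad_views:
        return torch.tensor(0.0)
    return torch.norm(torch.stack([torch.norm(v, 2.0) for v in grad_views]), 2.0)

views = [torch.ones(4), torch.ones(6)]
print(local_l2_norm(views))  # tensor(3.1623), i.e. sqrt(4 + 6)

In the distributed setting, this local value would then be raised to norm_type, all-reduced across model-parallel ranks (and, since the distributed optimizer shards gradients across data-parallel ranks, presumably across the data-parallel group as well), and the root taken, matching the total_norm = total_norm.item() ** (1.0 / norm_type) step visible in the clip_grad_norm_fp32 hunks above.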
@@ -31,6 +31,20 @@ from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from lutil import pax
get_clippy = lambda params : [ "%d, %d, %d ... %s" % (
p.grad is not None,
param_is_not_shared(p),
param_is_not_tensor_parallel_duplicate(p),
str(tuple(p.shape)),
) for p in params ]
# <<<
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
@@ -105,6 +119,17 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad):
# >>>
# model_params = [ p for m in self.models for p in m.parameters() ]
# optim_params = self.get_parameters()
# from lutil import pax
# pax(1, {
# "model_params" : get_clippy(model_params),
# "optim_params" : get_clippy(optim_params),
# })
# <<<
params = self.get_parameters()
return clip_grad_norm_fp32(
params, clip_grad,
@@ -408,91 +433,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def use_grad(p):
conditions = [
p.grad is not None,
param_is_not_shared(p),
param_is_not_tensor_parallel_duplicate(p),
# getattr(p, "shared", False),
]
return all(conditions)
# def print_module(m, d):
# ps = [ "%d/%s" % (
# use_grad(p),
# str(tuple(p.shape)),
# ) for p in m.parameters(recurse = False) ]
# ps = [
# str(tuple(p))
# for p in m.parameters(recurse = False)
# if use_grad(p)
# ]
# print("%s %s | %s" % (".." * d, type(m).__name__, ", ".join(ps)))
# if torch.distributed.get_rank() == 0:
# visited = []
# queue = [ (m, 0) for m in self.models ]
# while queue:
# m, d = queue.pop()
# visited.append((m, d))
# # print_module(m, d)
# queue.extend(reversed([ (mm, d + 1) for mm in m.children() ]))
# for m, d in visited:
# print_module(m, d)
for r in range(torch.distributed.get_world_size()):
if r == torch.distributed.get_rank():
# print("r %d, %s" % (
# torch.distributed.get_rank(),
# "".join(
# "%d" % use_grad(p)
# for m in self.models
# for p in m.parameters()
# ),
# ))
# print("r %d [ d %d, t %d, p %d ] ... %s" % (
# torch.distributed.get_rank(),
# mpu.get_data_parallel_rank(),
# mpu.get_tensor_model_parallel_rank(),
# mpu.get_pipeline_model_parallel_rank(),
# ", ".join(str(tuple(p.shape)) for p in self.get_parameters() if not use_grad(p)),
# ))
print("r %d [ d %d, t %d, p %d ] ... %d, %d ... %s" % (
torch.distributed.get_rank(),
mpu.get_data_parallel_rank(),
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum(p.nelement()
for p in self.get_parameters()
if use_grad(p)),
sum(p.nelement()
for p in self.get_parameters()
if not use_grad(p)),
"".join(
"%d" % use_grad(p)
for p in self.get_parameters()
),
))
torch.distributed.barrier()
torch.distributed.barrier()
exit(0)
# <<<
grad_norm = self.clip_grad_norm(self.clip_grad)
# >>>
from lutil import pax
pax(0, {
"use distrib opt" : args.use_distributed_optimizer,
"grad_norm" : grad_norm,
})
# <<<
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
@@ -607,6 +548,17 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# model_params = [ p for m in self.models for p in m.parameters() ]
# optim_params = self.get_parameters()
# model_params.sort(key = lambda p : p.nelement(), reverse = True)
# optim_params.sort(key = lambda p : p.nelement(), reverse = True)
# # assert len(model_params) == len(optim_params
# pax(7, {
# "model_params" : get_clippy(model_params),
# "optim_params" : get_clippy(optim_params),
# })
# <<<
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,