Commit 6728a780 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

grad norm 'matches' (not bitwise equal).

parent c6d20c05
...@@ -25,7 +25,8 @@ from megatron.model.module import param_is_not_shared ...@@ -25,7 +25,8 @@ from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, def clip_grad_norm_fp32(parameters, grads_for_norm,
max_norm, norm_type=2,
model_parallel_group=None): model_parallel_group=None):
"""Clips gradient norm of an iterable of parameters whose gradients """Clips gradient norm of an iterable of parameters whose gradients
are in fp32. are in fp32.
...@@ -50,42 +51,26 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ...@@ -50,42 +51,26 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]
# Filter parameters based on: # >>>
# - grad should not be none # # Filter parameters based on:
# - parameter should not be shared # # - grad should not be none
# - should not be a replica due to tensor model parallelism # # - parameter should not be shared
grads = [] # # - should not be a replica due to tensor model parallelism
grads_for_norm = [] # grads = []
for param in parameters: # grads_for_norm = []
grad_not_none = param.grad is not None # for param in parameters:
is_not_shared = param_is_not_shared(param) # grad_not_none = param.grad is not None
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) # is_not_shared = param_is_not_shared(param)
if grad_not_none: # is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
grad = param.grad.detach() # if grad_not_none:
if grad_not_none: # grad = param.grad.detach()
# Make sure the grads are in fp32 # if grad_not_none:
assert param.grad.type() == 'torch.cuda.FloatTensor' # # Make sure the grads are in fp32
grads.append(grad) # assert param.grad.type() == 'torch.cuda.FloatTensor'
if grad_not_none and is_not_shared and is_not_tp_duplicate: # grads.append(grad)
grads_for_norm.append(grad) # if grad_not_none and is_not_shared and is_not_tp_duplicate:
# >>> # grads_for_norm.append(grad)
else: # <<<
# from lutil import pax
# pax({"grad": grad})
from megatron import get_args
args = get_args()
for r in range(torch.distributed.get_world_size()):
if torch.distributed.get_rank() == r:
print("collect: r %d, dist-op %d, np %d, ne %d, g %s" % (
torch.distributed.get_rank(),
args.use_distributed_optimizer,
len(parameters),
sum(t.nelement() for t in parameters),
str(tuple(grad.shape)),
))
torch.distributed.barrier()
exit(0)
# <<<
# Norm parameters. # Norm parameters.
max_norm = float(max_norm) max_norm = float(max_norm)
...@@ -118,30 +103,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ...@@ -118,30 +103,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
# we need the pow(norm-type). # we need the pow(norm-type).
total_norm = grad_norm ** norm_type total_norm = grad_norm ** norm_type
# >>>
# from megatron import get_args
# from lutil import pax
# args = get_args()
# for r in range(torch.distributed.get_world_size()):
# if torch.distributed.get_rank() == r:
# print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
# torch.distributed.get_rank(),
# args.use_distributed_optimizer,
# grad_norm.item(),
# sum(t.nelement() for t in parameters),
# sum(t.nelement() for t in grads),
# sum(t.nelement() for t in grads_for_norm),
# ))
# torch.distributed.barrier()
# exit(0)
# pax(2, {
# "use distrib opt" : args.use_distributed_optimizer,
# "norm_type" : norm_type,
# "grad_norm" : grad_norm.item(),
# "total_norm" : total_norm.item(),
# })
# <<<
else: else:
for grad in grads_for_norm: for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type) grad_norm = torch.norm(grad, norm_type)
...@@ -154,14 +115,14 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ...@@ -154,14 +115,14 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
total_norm = total_norm.item() ** (1.0 / norm_type) total_norm = total_norm.item() ** (1.0 / norm_type)
# >>> # >>>
# from megatron import get_args from megatron import get_args
# from lutil import pax from lutil import pax
# args = get_args() args = get_args()
# pax(0, { pax(0, {
# "use distrib opt" : args.use_distributed_optimizer, "use distrib opt" : args.use_distributed_optimizer,
# "norm_type" : norm_type, "norm_type" : norm_type,
# "total_norm" : total_norm, "total_norm" : total_norm,
# }) })
# <<< # <<<
# Scale. # Scale.
......
...@@ -198,7 +198,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -198,7 +198,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# >>> # >>>
@classmethod @classmethod
def get_grad_views_for_grad_norm(cls, opt_group_shards, optimizer): def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
grad_views = [] grad_views = []
# grad_views_SKIPPED = [] # grad_views_SKIPPED = []
...@@ -285,7 +285,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -285,7 +285,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# <<< # <<<
# Params for grad norm. # Params for grad norm.
self.grad_views_for_grad_norm = self.get_grad_views_for_grad_norm( self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
self.opt_group_shards, self.opt_group_shards,
self.optimizer) self.optimizer)
...@@ -344,6 +344,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -344,6 +344,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def get_main_grad(self, group_index): def get_main_grad(self, group_index):
return self.get_main_param(group_index).grad return self.get_main_param(group_index).grad
# >>>
def _get_main_grads_for_grad_norm(self):
return self.main_grad_views_for_grad_norm
# <<<
def state_dict(self): def state_dict(self):
state_dict = {} state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict() state_dict['optimizer'] = self.optimizer.state_dict()
......
...@@ -112,6 +112,12 @@ class MegatronOptimizer(ABC): ...@@ -112,6 +112,12 @@ class MegatronOptimizer(ABC):
params.append(param) params.append(param)
return params return params
# >>>
@abstractmethod
# def get_grads_for_norm(self):
def _get_main_grads_for_grad_norm(self):
pass
# <<<
def get_model_parallel_group(self): def get_model_parallel_group(self):
'''Default returned here, but the distributed optimizer overrides this.''' '''Default returned here, but the distributed optimizer overrides this.'''
...@@ -119,20 +125,10 @@ class MegatronOptimizer(ABC): ...@@ -119,20 +125,10 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad): def clip_grad_norm(self, clip_grad):
# >>>
# model_params = [ p for m in self.models for p in m.parameters() ]
# optim_params = self.get_parameters()
# from lutil import pax
# pax(1, {
# "model_params" : get_clippy(model_params),
# "optim_params" : get_clippy(optim_params),
# })
# <<<
params = self.get_parameters() params = self.get_parameters()
grads_for_norm = self._get_main_grads_for_grad_norm()
return clip_grad_norm_fp32( return clip_grad_norm_fp32(
params, clip_grad, params, grads_for_norm, clip_grad,
model_parallel_group=self.get_model_parallel_group()) model_parallel_group=self.get_model_parallel_group())
...@@ -574,6 +570,36 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): ...@@ -574,6 +570,36 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
_zero_grad_group_helper(group, set_to_none) _zero_grad_group_helper(group, set_to_none)
def _get_main_grads_for_grad_norm(self):
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params = self.get_parameters()
# grads = []
grads_for_norm = []
for param in params:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
# if grad_not_none:
# grad = param.grad.detach()
# if grad_not_none:
# # Make sure the grads are in fp32
# assert param.grad.type() == 'torch.cuda.FloatTensor'
# grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# pax(0, {"grads_for_norm": [
# str(tuple(g.shape))
# for g in grads_for_norm
# ]})
return grads_for_norm
def _collect_main_grad_data_for_unscaling(self): def _collect_main_grad_data_for_unscaling(self):
main_grads = [] main_grads = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment