Commit 46879674 authored by mohammad's avatar mohammad
Browse files

clip grad fixed and moved to optimizer

parent 28062e14
......@@ -79,6 +79,7 @@ class PipelinedMegatronModule(MegatronModule):
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
......
......@@ -72,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
def clip_grad_norm(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
......@@ -89,51 +89,55 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
filtered_parameters = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(get_tensor_model_parallel_rank() == 0)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
filtered_parameters.append(param)
parameters = filtered_parameters
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm = max(param.grad.detach().abs().max()
for param in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
else:
total_norm = 0
for p in parameters:
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
total_norm += param_norm.item() ** norm_type
else:
for param in parameters:
param_norm = torch.norm(param.grad.detach(), norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
# Scale.
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for param in parameters:
param.grad.detach().mul_(clip_coef)
return total_norm
......@@ -19,6 +19,7 @@ from abc import ABC
from abc import abstractmethod
import torch
from torch._six import inf
from apex.multi_tensor_apply import multi_tensor_applier
from apex.optimizers import FusedAdam as Adam
......@@ -195,6 +196,77 @@ def _zero_grad_group_helper(group, set_to_none):
param.grad.zero_()
def _clip_grad_norm(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
filtered_parameters = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(mpu.get_tensor_model_parallel_rank() == 0)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
filtered_parameters.append(param)
parameters = filtered_parameters
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(param.grad.detach().abs().max()
for param in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
for param in parameters:
param_norm = torch.norm(param.grad.detach(), norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
# Scale.
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for param in parameters:
param.grad.detach().mul_(clip_coef)
return total_norm
class MegatronOptimizer(ABC):
......@@ -203,6 +275,13 @@ class MegatronOptimizer(ABC):
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
def clip_grad_norm(self, clip_grad):
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
_clip_grad_norm(params, clip_grad)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
......@@ -299,6 +378,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(master_param,
param)
if hasattr(param, 'shared'):
master_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
......@@ -408,11 +489,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
# ==========================
timers('optimizer-clip-master-grad').start()
fp32_params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
fp32_params.append(param)
mpu.clip_grad_norm(fp32_params, self.clip_grad)
self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-master-grad').stop()
# ===================
......@@ -473,13 +550,7 @@ class FP32Optimizer(MegatronOptimizer):
# Clip gradients.
if self.clip_grad > 0.0:
parameters = []
parameter_names = []
for parameter_name, parameter in self.model.named_parameters():
parameters.append(parameter)
parameter_names.append(parameter_name)
mpu.clip_grad_norm(parameters, self.clip_grad,
parameter_names=parameter_names)
self.clip_grad_norm(self.clip_grad)
# Update parameters.
self.optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment