Commit 46879674 authored by mohammad

clip grad fixed and moved to optimizer

parent 28062e14
...
@@ -79,6 +79,7 @@ class PipelinedMegatronModule(MegatronModule):
                 args.padded_vocab_size, args.hidden_size,
                 init_method=init_method_normal(args.init_method_std))
             self.word_embeddings.weight.data.fill_(0)
+            self.word_embeddings.weight.shared = True
         # Ensure that first and last stages have the same initial parameter values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
...
...
@@ -72,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
     return total_norm


-def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
+def clip_grad_norm(parameters, max_norm, norm_type=2):
     """Clips gradient norm of an iterable of parameters.

     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -89,42 +89,43 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
-    if parameter_names is not None:
-        filtered_parameters = []
-        assert len(parameters) == len(parameter_names), \
-            'length of parameters and parameter_names should be the same'
-        for p, n in zip(parameters, parameter_names):
-            if p.grad is not None:
-                # TODO: Bit hacky; is there a cleaner way to do this?
-                # Count embedding layer only once (in first stage).
-                # Don't count the weights a second time in the last stage.
-                if "embedding" not in n or \
-                   is_pipeline_first_stage():
-                    filtered_parameters.append(p)
-        parameters = filtered_parameters
-    else:
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    filtered_parameters = []
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = not hasattr(param, 'shared') or not param.shared
+        is_not_tp_duplicate = param.tensor_model_parallel or \
+            (get_tensor_model_parallel_rank() == 0)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            filtered_parameters.append(param)
+    parameters = filtered_parameters
+
+    # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)
+    total_norm = 0
+
+    # Calculate norm.
     if norm_type == inf:
-        total_norm = max(p.grad.data.abs().max() for p in parameters)
+        total_norm = max(param.grad.detach().abs().max()
+                         for param in parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
-        clip_coef = max_norm / (total_norm + 1e-6)
-        if clip_coef < 1:
-            for p in parameters:
-                p.grad.data.mul_(clip_coef)
     else:
-        total_norm = 0
-        for p in parameters:
-            if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
-                param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
-                total_norm += param_norm.item() ** norm_type
+        for param in parameters:
+            param_norm = torch.norm(param.grad.detach(), norm_type)
+            total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
@@ -132,8 +133,11 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())
         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+
+    # Scale.
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
-        for p in parameters:
-            p.grad.data.mul_(clip_coef)
+        for param in parameters:
+            param.grad.detach().mul_(clip_coef)

     return total_norm
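On a single process, the rule implemented by clip_grad_norm above reduces to the sketch below: compute the total norm of all gradients viewed as one vector, then scale every gradient in place by max_norm / (total_norm + 1e-6) whenever that coefficient is below 1. The helper name and tensors here are illustrative only and are not part of the repository:

import torch

def clip_grad_norm_local(parameters, max_norm, norm_type=2.0):
    """Single-process sketch of the clipping rule (no model-parallel reduce)."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    # Total norm of all gradients viewed as a single vector:
    # (sum_i ||g_i||_p ** p) ** (1 / p).
    total_norm = sum(torch.norm(g, norm_type).item() ** norm_type
                     for g in grads) ** (1.0 / norm_type)
    # Scale in place only when the norm exceeds the threshold.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm

The distributed version in the diff additionally all-reduces the per-rank sum of ||g_i||_p ** p over the model-parallel group before taking the 1/p root (or all-reduces the max for the infinity norm), and first filters out parameters whose gradient is None, that are marked shared, or that are tensor-model-parallel duplicates, so each weight is counted exactly once.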
...
@@ -19,6 +19,7 @@ from abc import ABC
 from abc import abstractmethod

 import torch
+from torch._six import inf

 from apex.multi_tensor_apply import multi_tensor_applier
 from apex.optimizers import FusedAdam as Adam
@@ -195,6 +196,77 @@ def _zero_grad_group_helper(group, set_to_none):
         param.grad.zero_()
+def _clip_grad_norm(parameters, max_norm, norm_type=2):
+    """Clips gradient norm of an iterable of parameters.
+
+    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
+    added functionality to handle model parallel parameters. Note that
+    the gradients are modified in place.
+
+    Arguments:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    filtered_parameters = []
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = not hasattr(param, 'shared') or not param.shared
+        is_not_tp_duplicate = param.tensor_model_parallel or \
+            (mpu.get_tensor_model_parallel_rank() == 0)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            filtered_parameters.append(param)
+    parameters = filtered_parameters
+
+    # Norm parameters.
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    total_norm = 0.0
+
+    # Calculate norm.
+    if norm_type == inf:
+        total_norm = max(param.grad.detach().abs().max()
+                         for param in parameters)
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        # Take max across all model-parallel GPUs.
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=mpu.get_model_parallel_group())
+        total_norm = total_norm_cuda[0].item()
+    else:
+        for param in parameters:
+            param_norm = torch.norm(param.grad.detach(), norm_type)
+            total_norm += param_norm.item() ** norm_type
+        # Sum across all model-parallel GPUs.
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
+        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
+
+    # Scale.
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for param in parameters:
+            param.grad.detach().mul_(clip_coef)
+
+    return total_norm
+
+
 class MegatronOptimizer(ABC):
@@ -203,6 +275,13 @@ class MegatronOptimizer(ABC):
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'

+    def clip_grad_norm(self, clip_grad):
+        params = []
+        for param_group in self.optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
+        _clip_grad_norm(params, clip_grad)
+
     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass
@@ -299,6 +378,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                         # Copy tensor model parallel attributes.
                         mpu.copy_tensor_model_parallel_attributes(master_param,
                                                                   param)
+                        if hasattr(param, 'shared'):
+                            master_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
@@ -408,11 +489,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # ==========================
         timers('optimizer-clip-master-grad').start()
-        fp32_params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                fp32_params.append(param)
-        mpu.clip_grad_norm(fp32_params, self.clip_grad)
+        self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-master-grad').stop()

         # ===================
@@ -473,13 +550,7 @@ class FP32Optimizer(MegatronOptimizer):

         # Clip gradients.
         if self.clip_grad > 0.0:
-            parameters = []
-            parameter_names = []
-            for parameter_name, parameter in self.model.named_parameters():
-                parameters.append(parameter)
-                parameter_names.append(parameter_name)
-            mpu.clip_grad_norm(parameters, self.clip_grad,
-                               parameter_names=parameter_names)
+            self.clip_grad_norm(self.clip_grad)

         # Update parameters.
         self.optimizer.step()
...
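Taken together with the model change above, the new filter means the duplicated copy of the tied word-embedding weight is skipped during clipping: the model tags that copy with a shared attribute, and _clip_grad_norm drops any parameter carrying it. A standalone illustration of just that check, with made-up tensors (nothing below comes from the repository):

import torch

# Stand-in for the tied embedding copy that the model tags via
# `self.word_embeddings.weight.shared = True` so it is not double-counted.
embedding_copy = torch.nn.Parameter(torch.zeros(8, 4))
embedding_copy.grad = torch.randn(8, 4)
embedding_copy.shared = True

regular_weight = torch.nn.Parameter(torch.randn(4, 4))
regular_weight.grad = torch.randn(4, 4)

def is_not_shared(param):
    # Same condition used by the new filter in _clip_grad_norm.
    return not hasattr(param, 'shared') or not param.shared

kept = [p for p in (embedding_copy, regular_weight)
        if p.grad is not None and is_not_shared(p)]
print(len(kept))  # 1 -- only the untagged weight enters the norm and clip

The FP16 path preserves the tag when it builds fp32 master copies (the hasattr(param, 'shared') branch in the diff), so the same filter applies whether clipping runs over fp16 model parameters or their fp32 masters.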