Commit dfd8ed47 authored by mohammad

grads is removed from mpu

parent 46879674
@@ -209,7 +209,7 @@ def to_python_float(t):
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
-clip_grad_norm = mpu.clip_grad_norm
+clip_grad_norm = None #mpu.clip_grad_norm
 # elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
 #     clip_grad_norm = torch.nn.utils.clip_grad_norm
 # else:
...
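With clip_grad_norm rebound to None here, any remaining call site in this module has to tolerate the missing helper; the commit moves clipping into the optimizer code instead. A minimal caller-side guard (a sketch, not part of this commit; model and args.clip_grad are illustrative names):

    # Hypothetical guard at a former clip_grad_norm call site:
    if clip_grad_norm is not None:
        grad_norm = clip_grad_norm(model.parameters(), args.clip_grad)
    else:
        grad_norm = None  # clipping is expected to happen in the optimizer code instead
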
@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
 from .data import broadcast_data
-from .grads import clip_grad_norm
 from .initialize import is_unitialized
 from .initialize import destroy_model_parallel
 from .initialize import get_data_parallel_group
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import torch
from torch._six import inf
try:
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
except Exception as e:
    print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')

from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_rank


def l2_grad_clipper(parameters, max_norm):
"""Efficient L2 norm gradient clipping."""
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
# Make sure we have an iterable.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters with gradients.
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
overflow_buf,
[parameters_for_norm],
False # no per-parameter norm
)
# Sum across all model parallel GPUs.
norm_2 = norm * norm
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = norm_2.item() ** 0.5
# Scale to get max_norm.
clip_coef = float(max_norm) / (total_norm + 1.0e-6)
grads = [p.grad for p in parameters_with_grads]
if clip_coef < 1.0:
multi_tensor_applier(
amp_C.multi_tensor_scale,
overflow_buf,
[grads, grads],
clip_coef)
    return total_norm

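# Note (sketch, not in the original file): the reduction above computes
#   total_norm = sqrt(sum over model-parallel ranks of local_l2_norm ** 2),
# and the gradients are then scaled by max_norm / total_norm only when that
# ratio is below one, i.e. effectively by min(1, max_norm / total_norm).
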
def clip_grad_norm(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
filtered_parameters = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(get_tensor_model_parallel_rank() == 0)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
filtered_parameters.append(param)
parameters = filtered_parameters
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0
# Calculate norm.
if norm_type == inf:
total_norm = max(param.grad.detach().abs().max()
for param in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
for param in parameters:
param_norm = torch.norm(param.grad.detach(), norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
# Scale.
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for param in parameters:
param.grad.detach().mul_(clip_coef)
return total_norm
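For reference, a minimal usage sketch of the clip_grad_norm helper deleted above (the surrounding training-step names such as model, batch and optimizer are illustrative, not taken from this commit):

    # Hypothetical training step using the removed mpu.clip_grad_norm:
    loss = model(batch)              # assume a scalar loss
    loss.backward()
    total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)  # clips in place, returns the pre-clip norm
    optimizer.step()
    optimizer.zero_grad()
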
@@ -38,7 +38,7 @@ def get_params_for_weight_decay_optimization(module):
     args = get_args()
     LayerNorm = import_layernorm(args.fp32_residual_connection)
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module_ in module.modules():
@@ -63,8 +63,11 @@ def get_megatron_optimizer(model):
     # Base optimizer.
     param_groups = get_params_for_weight_decay_optimization(model)
-    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
-                     betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
+    optimizer = Adam(param_groups,
+                     lr=args.lr,
+                     weight_decay=args.weight_decay,
+                     betas=(args.adam_beta1, args.adam_beta2),
+                     eps=args.adam_eps)

     if args.fp16:
         # Constant loss scale.
@@ -213,7 +216,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
@@ -247,7 +250,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
                                      group=mpu.get_model_parallel_group())
         total_norm = total_norm_cuda[0].item()
     else:
         for param in parameters:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
...
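The hunk above keeps the same reduction scheme as the deleted mpu.clip_grad_norm: each rank accumulates sum(|grad| ** norm_type) over its local shard, the partial sums are all-reduced across the model-parallel group, and the global norm is the norm_type-th root of the total. A standalone sketch of that combination step (function and variable names are illustrative, not from this diff):

    import torch
    import torch.distributed as dist

    def combine_local_grad_norms(local_pow_sum, norm_type, group):
        # local_pow_sum: sum of |grad| ** norm_type over this rank's parameters
        total = torch.cuda.FloatTensor([float(local_pow_sum)])
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group)
        return total[0].item() ** (1.0 / norm_type)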