Commit 9a297541 authored by Jared Casper

Merge branch 'optimize_refactor' into 'main'

Optimizer refactor

See merge request ADLR/megatron-lm!197
parents b69bc7ef 42c6d8e7
......@@ -4,7 +4,7 @@ import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.module import MegatronModule
from .module import MegatronModule
from megatron import mpu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......
......@@ -21,7 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
......
......@@ -20,7 +20,6 @@ import math
import torch
from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
......@@ -60,28 +59,3 @@ def openai_gelu(x):
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
......@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
......@@ -46,7 +44,9 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
......
......@@ -37,14 +37,48 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
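
For orientation, a small sketch of how these helpers behave on plain tensors (illustrative only; it assumes the three functions defined above are in scope):

import torch

weight = torch.empty(4, 4)
set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0, stride=1)
assert weight.tensor_model_parallel is True
assert weight.partition_dim == 0 and weight.partition_stride == 1

bias = torch.empty(4)
set_defaults_if_not_set_tensor_model_parallel_attributes(bias)
# The defaults mark the tensor as not tensor-model-parallel.
assert bias.tensor_model_parallel is False and bias.partition_dim == -1

main_param = torch.empty(4, 4)
copy_tensor_model_parallel_attributes(main_param, weight)
assert main_param.partition_dim == weight.partition_dim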
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
......@@ -58,9 +92,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from megatron import get_args
from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
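
For intuition, a self-contained analogue of the grouping above, using torch.nn.LayerNorm in place of Megatron's fused LayerNorm (illustrative only, not part of this commit): the per-group 'weight_decay': 0.0 entry overrides the optimizer-level default, so LayerNorm parameters and biases are excluded from decay.

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
    if isinstance(module_, torch.nn.LayerNorm):
        no_weight_decay_params['params'].extend(
            p for p in module_._parameters.values() if p is not None)
    else:
        weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n != 'bias')
        no_weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n == 'bias')

# Linear.weight decays; Linear.bias and both LayerNorm parameters do not.
optimizer = torch.optim.Adam([weight_decay_params, no_weight_decay_params],
                             lr=1e-3, weight_decay=0.01)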
def get_megatron_optimizer(model):
args = get_args()
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
if args.fp16:
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad)
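
A hedged sketch of the call pattern the rest of this MR builds around the returned wrapper (loss_fn, batch, increment, and lr_scheduler below are placeholders, not part of this commit): the wrapper hides the fp16/fp32 distinction behind scale_loss() and step().

optimizer = get_megatron_optimizer(model)
optimizer.zero_grad()
loss = loss_fn(batch)                        # placeholder forward/loss computation
optimizer.scale_loss(loss).backward()        # identity scaling for FP32Optimizer
update_successful = optimizer.step()         # False when fp16 grads overflow
if update_successful:
    lr_scheduler.step(increment=increment)   # mirrors train_step() further down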
......@@ -13,67 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
"""Gradient clipping."""
import torch
from torch._six import inf
try:
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from megatron import mpu
def l2_grad_clipper(parameters, max_norm):
"""Efficient L2 norm gradient clipping."""
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
# Make sure we have an iterable.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters with gradients.
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
overflow_buf,
[parameters_for_norm],
False # no per-parameter norm
)
# Sum across all model parallel GPUs.
norm_2 = norm * norm
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = norm_2.item() ** 0.5
# Scale to get max_norm.
clip_coef = float(max_norm) / (total_norm + 1.0e-6)
grads = [p.grad for p in parameters_with_grads]
if clip_coef < 1.0:
multi_tensor_applier(
amp_C.multi_tensor_scale,
overflow_buf,
[grads, grads],
clip_coef)
return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters.
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model parallel parameters. Note that
......@@ -89,51 +42,78 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(mpu.get_tensor_model_parallel_rank() == 0)
if grad_not_none:
grad = param.grad.detach()
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
else:
total_norm = 0
for p in parameters:
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
total_norm += param_norm.item() ** norm_type
if norm_type == 2.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of lists
# and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
# Since we will be summing across model-parallel GPUs,
# we need the norm raised to the power of norm_type.
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
torch.distributed.all_reduce(total_norm,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
group=mpu.get_model_parallel_group())
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm
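
For intuition, a single-process, pure-torch analogue of the L2 branch above, without the apex kernels or the model-parallel all-reduce:

import torch

grads = [torch.randn(10), torch.randn(5)]
max_norm = 1.0
# total_norm = (sum_i ||g_i||_2^2) ** 0.5 over all gradients.
total_norm = sum(g.norm(2.0) ** 2 for g in grads) ** 0.5
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for g in grads:
        g.mul_(clip_coeff)
# The combined gradient norm is now at most (approximately) max_norm.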
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron grad scaler."""
from abc import ABC
from abc import abstractmethod
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.cuda.FloatTensor([initial_scale])
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
class ConstantGradScaler(MegatronGradScaler):
def update(self, found_inf):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
def __init__(self, initial_scale, min_scale,
growth_factor, backoff_factor,
growth_interval, hysteresis):
""""Grad scaler with dynamic scale that gets adjusted
during training."""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.cuda.FloatTensor([min_scale])
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
# If we don't see any inf/nan over this many consecutive iterations,
# we will scale up the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf):
# If we have an inf/nan, growth tracker is set to 0
# and hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss scale.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the growth and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
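
A small behavioral sketch of the scaler (requires a CUDA device, since the scale lives in a torch.cuda.FloatTensor); the values below simply illustrate the backoff/growth rules implemented in update():

scaler = DynamicGradScaler(initial_scale=2.0 ** 16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=1000, hysteresis=2)
scaler.update(found_inf=True)       # first overflow only consumes hysteresis
scaler.update(found_inf=True)       # second overflow: scale is halved (backoff)
for _ in range(1000):
    scaler.update(found_inf=False)  # 1000 clean steps: scale is doubled (growth)
print(scaler.scale)                 # back at the initial 65536.0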
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron optimizer."""
from abc import ABC
from abc import abstractmethod
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from .clip_grads import clip_grad_norm_fp32
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer."""
for param in group:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""Use multi-tensor-applier to copy values from one list to another."""
if overflow_buf is not None:
overflow_buf.fill_(0)
else:
overflow_buf = torch.cuda.IntTensor([0])
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
overflow_buf,
[this, that],
1.0)
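
Usage sketch (requires apex and a CUDA device): with a scale factor of 1.0, the kernel writes `this` element-wise into `that`, so the call below behaves like a batched copy.

import torch

src = [torch.randn(3, device='cuda'), torch.randn(4, 4, device='cuda')]
dst = [torch.empty_like(t) for t in src]
_multi_tensor_copy_this_to_that(this=src, that=dst)
assert all(torch.equal(s, d) for s, d in zip(src, dst))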
class MegatronOptimizer(ABC):
def __init__(self, optimizer):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
def clip_grad_norm(self, clip_grad):
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
clip_grad_norm_fp32(params, clip_grad)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
@abstractmethod
def get_loss_scale(self):
"""The output should be a cuda tensor of size 1."""
pass
def scale_loss(self, loss):
"""Simple scaling."""
return self.get_loss_scale() * loss
@abstractmethod
def step(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
Call whenever the parameters are changed outside of the optimizer.
For example, when we load a model from a checkpoint without loading
the optimizer, the model parameters are updated, but for an fp16 optimizer
with main parameters, the main parameters also need to be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
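
Because state and param_groups are promoted, the wrapper can stand in for a plain torch optimizer wherever those attributes are read or written, e.g. by the learning-rate scheduler. Sketch (opt stands for any MegatronOptimizer subclass):

for group in opt.param_groups:       # proxies opt.optimizer.param_groups
    group['lr'] = 1.0e-4
assert opt.param_groups is opt.optimizer.param_groups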
class FP16OptimizerWithFP16Params(MegatronOptimizer):
def __init__(self, optimizer, grad_scaler, clip_grad):
super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
self.grad_scaler = grad_scaler
self.clip_grad = clip_grad
# Tensor used to determine if a nan/inf has happened.
# Any non-zero value indicates inf/nan.
self.found_inf = torch.cuda.FloatTensor([0.0])
# Dummy tensor needed for apex's multi-tensor applier.
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
# ======================
# main parameter stuff
# ======================
# Three groups of parameters:
# fp16_groups: original fp16 parameters
# fp32_from_fp16_groups: fp32 copy of fp16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.fp16_groups = []
self.fp32_from_fp16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# fp16 params:
if param.type() == 'torch.cuda.HalfTensor':
fp16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Store grads
main_param.requires_grad = True
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param,
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_fp16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Wrapped parameters must be either "
"torch.cuda.FloatTensor or "
"torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
self.fp16_groups.append(fp16_params_this_group)
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
fp16_groups & fp32_from_fp32_groups."""
for group in self.fp16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
return self.grad_scaler.scale
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the fp16 group.
model_grads = []
main_grads = []
for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
for model_param, main_param in zip(model_group, main_group):
if model_param.grad is not None:
if main_param.grad is None:
main_param.grad = torch.empty_like(main_param)
model_grads.append(model_param.grad.data)
main_grads.append(main_param.grad.data)
_multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
overflow_buf=self._dummy_overflow_buf)
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
# fp32 params from fp16 ones.
for main_group in self.fp32_from_fp16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
return found_inf_flag
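
For context, a pure-torch numeric sketch (runs on CPU, no Megatron code) of the scale/unscale round trip this check relies on: gradients of a scaled loss come out multiplied by the scale, so multiplying by inv_scale recovers the true gradients before the inf/nan check.

import torch

scale = 1024.0
p = torch.ones(3, requires_grad=True)
loss = (p * 2.0).sum()
(loss * scale).backward()            # p.grad is now scale * d(loss)/dp = 2048
unscaled = p.grad / scale
assert torch.allclose(unscaled, torch.full((3,), 2.0))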
def _get_model_and_main_params_data_fp16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self):
# Only needed for the fp16 params.
model_data, main_data = self._get_model_and_main_params_data_fp16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
def _copy_model_params_to_main_params(self):
# Only needed for the fp16 params.
model_data, main_data = self._get_model_and_main_params_data_fp16()
_multi_tensor_copy_this_to_that(this=model_data, that=main_data,
overflow_buf=self._dummy_overflow_buf)
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# Step the optimizer.
self.optimizer.step()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
return state_dict
def load_state_dict(self, state_dict):
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
print_rank_0('***WARNING*** loading optimizer from '
'an old checkpoint ...')
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
# Copy data for the main params.
fp32_from_fp16_params_key = 'fp32_from_fp16_params'
if fp32_from_fp16_params_key not in state_dict:
fp32_from_fp16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_fp16_groups,
state_dict[fp32_from_fp16_params_key]):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad):
super(FP32Optimizer, self).__init__(optimizer)
self.clip_grad = clip_grad
self._scale = torch.cuda.FloatTensor([1.0])
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
def get_loss_scale(self):
"""FP32 optimizer does not do any scaling."""
return self._scale
@torch.no_grad()
def step(self):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Clip gradients.
if self.clip_grad > 0.0:
self.clip_grad_norm(self.clip_grad)
# Update parameters.
self.optimizer.step()
# No overflow for FP32 optimizer.
return True
def reload_model_params(self):
pass
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
......@@ -24,7 +24,6 @@ _TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
from megatron import get_args
from megatron import get_timers
......@@ -38,13 +37,13 @@ from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.model import FP16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
......@@ -183,6 +182,13 @@ def get_model(model_provider_func):
# Build model on cpu.
model = model_provider_func()
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for param in model.parameters():
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
......@@ -196,7 +202,7 @@ def get_model(model_provider_func):
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
model = FP16Module(model)
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
......@@ -211,38 +217,6 @@ def get_model(model_provider_func):
'Exiting.'.format(args.DDP_impl))
def get_optimizer(model):
"""Set up the optimizer."""
args = get_args()
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
model = model.module
param_groups = get_params_for_weight_decay_optimization(model)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'tensor_model_parallel'):
param.tensor_model_parallel = False
# Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
# Wrap into fp16 optimizer.
if args.fp16:
optimizer = FP16_Optimizer(optimizer,
static_loss_scale=args.loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={
'scale_window': args.loss_scale_window,
'min_scale': args.min_scale,
'delayed_shift': args.hysteresis})
return optimizer
def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
......@@ -291,7 +265,12 @@ def setup_model_and_optimizer(model_provider_func):
args = get_args()
model = get_model(model_provider_func)
optimizer = get_optimizer(model)
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
unwrapped_model = unwrapped_model.module
optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.load is not None:
......@@ -382,11 +361,9 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
input_tensor.retain_grad()
# Backward pass.
if args.fp16:
optimizer.backward(output_tensor, update_master_grads=False,
output_tensor_grad=output_tensor_grad)
else:
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
......@@ -605,10 +582,7 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers()
# Set grad to zero.
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
else:
optimizer.zero_grad()
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_pipelining(
......@@ -632,7 +606,7 @@ def train_step(forward_step_func, data_iterator,
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
unwrapped_model = unwrapped_model.module
if unwrapped_model.share_word_embeddings:
......@@ -641,40 +615,18 @@ def train_step(forward_step_func, data_iterator,
group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update master gradients.
timers('backward-master-grad').start()
if args.fp16:
optimizer.update_master_grads()
timers('backward-master-grad').stop()
# Clipping gradients helps prevent the exploding gradient.
timers('backward-clip-grad').start()
if args.clip_grad > 0.:
if not args.fp16:
named_parameters = model.named_parameters()
parameters = []
parameter_names = []
for parameter_name, parameter in model.named_parameters():
parameters.append(parameter)
parameter_names.append(parameter_name)
mpu.clip_grad_norm(parameters, args.clip_grad,
parameter_names=parameter_names)
else:
optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
update_successful = optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
......@@ -738,10 +690,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-master-grad')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-clip-grad')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
......@@ -764,10 +718,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if args.fp16:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
......@@ -793,8 +746,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
......@@ -858,9 +810,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
get_num_microbatches()
# Logging.
loss_scale = None
if args.fp16:
loss_scale = optimizer.loss_scale
loss_scale = optimizer.get_loss_scale().item()
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
......
......@@ -24,7 +24,6 @@ from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer
def average_losses_across_data_parallel_group(losses):
......@@ -47,11 +46,13 @@ def report_memory(name):
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
string += ' | reserved: {}'.format(
torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
flush=True)
def print_params_min_max_norm(optimizer, iteration):
......@@ -59,9 +60,7 @@ def print_params_min_max_norm(optimizer, iteration):
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
......
......@@ -186,7 +186,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Logging.
report_memory_flag = training_log(losses_dict, losses_dict_sum,
optimizer.param_groups[0]['lr'],
iteration, optimizer.loss_scale,
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag, skipped_iter)
# Autoresume
......@@ -255,9 +256,8 @@ def finetune(train_valid_datasets_provider, model_provider,
_ = load_checkpoint(model, None, None)
args.load = original_load
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
if args.fp16:
optimizer._model_params_to_master_params()
# main parameters are also updated.
optimizer.reload_model_params()
timers('pretrained checkpoint').stop()
# Print setup timing.
......