Commit aed2f75e authored by Jared Casper

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
@@ -19,7 +19,8 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
@@ -27,46 +28,57 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
class ClassificationBase(MegatronModule): class Classification(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2): def __init__(self,
super(ClassificationBase, self).__init__(share_word_embeddings=False) num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(Classification, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
self.num_classes = num_classes self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head. # Multi-choice head.
if mpu.is_pipeline_last_stage(): if self.post_process:
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size, self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes, self.num_classes,
init_method) init_method)
self._classification_head_key = 'classification_head' self._classification_head_key = 'classification_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None): def forward(self, model_input, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask) extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
kwargs = {} if self.post_process:
if mpu.is_pipeline_first_stage():
input_ids = model_input
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output _, pooled_output = lm_output
classification_output = self.classification_dropout(pooled_output) classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output) classification_logits = self.classification_head(classification_output)
@@ -86,7 +98,7 @@ class ClassificationBase(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
state_dict_[self._classification_head_key] \ state_dict_[self._classification_head_key] \
= self.classification_head.state_dict( = self.classification_head.state_dict(
destination, prefix, keep_vars) destination, prefix, keep_vars)
@@ -97,7 +109,7 @@ class ClassificationBase(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self._classification_head_key in state_dict: if self._classification_head_key in state_dict:
self.classification_head.load_state_dict( self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict) state_dict[self._classification_head_key], strict=strict)
@@ -105,55 +117,3 @@ class ClassificationBase(MegatronModule):
print_rank_last('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format( 'initializing to random'.format(
self._classification_head_key)) self._classification_head_key))
class Classification(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(Classification, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationFirstStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationFirstStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(ClassificationFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationIntermediateStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationIntermediateStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationIntermediateStage, self).forward(
hidden_state,
attention_mask)
class ClassificationLastStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationLastStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationLastStage, self).forward(
hidden_state,
attention_mask)
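# --- Illustrative aside (not part of this diff) ---------------------------
# A minimal sketch of what a bert_position_ids-style helper is assumed to do:
# build position ids [0, 1, ..., seq_len-1] with the same batch shape as the
# input ids, so they can be fed to the position embeddings.
import torch

input_ids = torch.tensor([[101, 2023, 2003, 102],
                          [101, 7592,  102,   0]])
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
                            device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# position_ids -> tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
# --------------------------------------------------------------------------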
@@ -13,100 +13,206 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import ABC
from abc import abstractmethod
import torch import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
class DistributedDataParallel(MegatronModule):
def __init__(self, module): class MemoryBuffer:
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False def __init__(self, numel, dtype):
self.numel = numel
self.dtype = dtype
self.data = torch.zeros(self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
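# --- Illustrative aside (not part of this diff) ---------------------------
# Standalone sketch of the MemoryBuffer idea above: one flat, pre-allocated
# tensor hands out reshaped views, so many "gradients" share a single
# contiguous allocation instead of fragmenting memory. CPU tensors are used
# here purely for illustration.
import torch

flat = torch.zeros(12, dtype=torch.float)        # 1-D backing storage
grad_a = flat[0:6].view(2, 3)                    # view for a (2, 3) gradient
grad_b = flat[6:12].view(3, 2)                   # view for a (3, 2) gradient
grad_a.add_(1.0)                                 # writing through the view...
assert flat[:6].sum().item() == 6.0              # ...updates the shared storage
# --------------------------------------------------------------------------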
class DistributedDataParallelBase(MegatronModule, ABC):
"""Abstract class for DDP."""
def __init__(self, module):
super(DistributedDataParallelBase, self).__init__()
# Keep a pointer to the model.
self.module = module self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): @abstractmethod
if(self.needs_reduction): def allreduce_gradients(self):
self.needs_reduction = False pass
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs) return self.module(*inputs, **kwargs)
def state_dict(self, destination=None, prefix='', keep_vars=False): def state_dict(self, destination=None, prefix='', keep_vars=False):
#[h.remove() for h in self.hook_handles] return self.module.state_dict(destination, prefix, keep_vars)
sd = self.module.state_dict(destination, prefix, keep_vars)
# for handle, hook in zip(self.hook_handles, self.hooks):
# d = handle.hooks_dict_ref()
# d[handle.id] = hook
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix, return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars) keep_vars)
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict) self.module.load_state_dict(state_dict, strict=strict)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers()) class DistributedDataParallel(DistributedDataParallelBase):
if len(buffers) > 0: """DDP with contiguous buffer options to store and accumulate gradients.
# cross-node buffer sync This class:
flat_buffers = _flatten_dense_tensors(buffers) - has the potential to reduce memory fragmentation.
dist.broadcast(flat_buffers, 0) - provides the option to do the gradient accumulation
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): in a type other than the params type (for example fp32)
buf.copy_(synced)
def train(self, mode=True): Arguments:
# Clear NCCL communicator and CUDA event cache of the default group ID, module: input model.
# These cache will be recreated at the later call. This is currently a accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
# work-around for a potential NCCL deadlock. and the gradient all-reduce all in float32. If this option is
if dist._backend == dist.dist_backend.NCCL: true, we require `use_contiguous_buffers` to be true too.
dist._clear_group_cache() use_contiguous_buffers: if true, use a contiguous buffer to store the
super(DistributedDataParallel, self).train(mode) gradients.
self.module.train(mode) """
'''
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
super(DistributedDataParallel, self).__init__(module)
self.accumulate_allreduce_grads_in_fp32 \
= accumulate_allreduce_grads_in_fp32
self.use_contiguous_buffers = use_contiguous_buffers
# If we are using fp32-accumulate-allreduce explicitly
# this means we need main grads in a contiguous buffer.
if self.accumulate_allreduce_grads_in_fp32:
assert self.use_contiguous_buffers
# ===================================
# Rest of this part applies only to
# the case we use contiguous buffers.
# ===================================
self._grad_buffers = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
# Simple function to define buffer type.
def _get_buffer_type(param):
return torch.float if \
self.accumulate_allreduce_grads_in_fp32 else param.dtype
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# Assume the back prop order is the reverse of the params order,
# store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
# Backward hook.
# Accumulation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
beginning of each iteration."""
assert self._grad_buffers is not None, 'buffers are not initialized.'
for _, buffer_ in self._grad_buffers.items():
buffer_.zero()
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
if self._grad_buffers is not None:
for _, buffer_ in self._grad_buffers.items():
buffer_.data /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
buffer_.data, group=mpu.get_data_parallel_group())
else:
# Otherwise, bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_data_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
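# --- Illustrative aside (not part of this diff) ---------------------------
# Standalone sketch of the backward-hook pattern used above: reach the
# parameter's grad-accumulation node via expand_as(), and on every backward
# pass fold param.grad into a persistent fp32 `main_grad`, then drop
# param.grad so its memory can be released. No distributed setup is needed
# for this toy example.
import torch

param = torch.nn.Parameter(torch.randn(3))
param.main_grad = torch.zeros(3, dtype=torch.float)

def make_hook(p):
    def hook(*unused):
        if p.grad is not None:
            p.main_grad.add_(p.grad.data.float())   # accumulate in fp32
            p.grad = None                           # free the per-step grad
    return hook

# expand_as(param) gives a non-leaf view whose grad_fn points back at the
# AccumulateGrad node of `param`; hooks on that node fire after .grad is set.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
grad_acc.register_hook(make_hook(param))

(param * 2.0).sum().backward()
assert torch.allclose(param.main_grad, torch.full((3,), 2.0))
# --------------------------------------------------------------------------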
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
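# --- Illustrative aside (not part of this diff) ---------------------------
# Rough sketch of the two mask flavours the enum above distinguishes:
# `causal` hides future positions (GPT-style), `padding` hides padded key
# positions (BERT-style). Shapes and names here are illustrative only.
import torch

seq_len = 4
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                         diagonal=1)       # True above the diagonal = masked

lengths = torch.tensor([4, 2])             # valid tokens per sample
positions = torch.arange(seq_len).unsqueeze(0)
padding_mask = positions >= lengths.unsqueeze(1)   # True at padded positions
# --------------------------------------------------------------------------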
@@ -15,29 +15,23 @@
"""This code is copied fron NVIDIA apex: """This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex https://github.com/NVIDIA/apex
with minor changes. """ with some changes. """
import math
import torch
import numbers import numbers
import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
from torch.nn import functional as F
import importlib import importlib
global fused_layer_norm_cuda
fused_layer_norm_cuda = None
global fused_mix_prec_layer_norm_cuda global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function): class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps): def forward(ctx, input, weight, bias, normalized_shape, eps):
global fused_mix_prec_layer_norm_cuda
if fused_mix_prec_layer_norm_cuda is None:
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
ctx.normalized_shape = normalized_shape ctx.normalized_shape = normalized_shape
ctx.eps = eps ctx.eps = eps
input_ = input.contiguous() input_ = input.contiguous()
@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps) input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar) ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape, input_, ctx.normalized_shape,
weight_, bias_, ctx.eps) weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod return grad_input, grad_weight, grad_bias, None, None
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, mean, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
ctx.eps)
return grad_input, None, None
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
return FusedLayerNormFunction.apply(input, normalized_shape, eps)
class MixedFusedLayerNorm(torch.nn.Module): class MixedFusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ . def __init__(self, normalized_shape, eps=1e-5):
Currently only runs on cuda() tensors.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedLayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedLayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(MixedFusedLayerNorm, self).__init__() super(MixedFusedLayerNorm, self).__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
global fused_mix_prec_layer_norm_cuda global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda") fused_mix_prec_layer_norm_cuda = importlib.import_module(
"fused_mix_prec_layer_norm_cuda")
if isinstance(normalized_shape, numbers.Integral): if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,) normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape) self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.weight = Parameter(torch.Tensor(*normalized_shape))
if self.elementwise_affine: self.bias = Parameter(torch.Tensor(*normalized_shape))
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine: def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias) init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
if not input.is_cuda:
return F.layer_norm( def forward(self, input):
input, self.normalized_shape, self.weight, self.bias, self.eps)
return FusedLayerNormAffineFunction.apply(
if self.elementwise_affine: input, self.weight, self.bias, self.normalized_shape,self.eps)
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape,self.eps)
else:
return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
def extra_repr(self):
return '{normalized_shape}, eps={eps}, ' \
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
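# --- Illustrative aside (not part of this diff) ---------------------------
# Reference-only sketch: numerically, the fused affine layer norm above
# matches torch.nn.functional.layer_norm with a learnable per-element weight
# and bias over `normalized_shape`.
import torch
import torch.nn.functional as F

hidden = 16
x = torch.randn(4, hidden)
weight = torch.ones(hidden)
bias = torch.zeros(hidden)
y = F.layer_norm(x, (hidden,), weight, bias, eps=1e-5)
# --------------------------------------------------------------------------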
@@ -14,114 +14,151 @@
# limitations under the License. # limitations under the License.
import torch import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
""" """
Fused operation which performs the following three operations in sequence Fused operation which performs the following three operations in sequence
1. Scale the tensor. 1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models). 2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax. 3. Perform softmax.
""" """
@staticmethod @staticmethod
def forward(ctx, inputs, scale): def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = \ softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = \ input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
scaled_upper_triang_masked_softmax_cuda.backward(output_grads, output_grads, softmax_results, scale_t[0]
softmax_results, )
scale_t[0])
return input_grads, None return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function) :
class ScaledMaskedSoftmax(torch.autograd.Function):
""" """
Fused operation which performs the following three operations in sequence Fused operation which performs the following three operations in sequence
1. Scale the tensor. 1. Scale the tensor.
2. Apply the mask. 2. Apply the mask.
3. Perform softmax. 3. Perform softmax.
""" """
@staticmethod @staticmethod
def forward(ctx, inputs, mask, scale): def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = \ softmax_results = scaled_masked_softmax_cuda.forward(
scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) inputs, mask, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
import scaled_masked_softmax_cuda import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = \ input_grads = scaled_masked_softmax_cuda.backward(
scaled_masked_softmax_cuda.backward(output_grads, output_grads, softmax_results, scale_t[0]
softmax_results, )
scale_t[0])
return input_grads, None, None return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module): class FusedScaleMaskSoftmax(torch.nn.Module):
""" """
fused operation: scaling + mask + softmax fused operation: scaling + mask + softmax
Arguments: Arguments:
input_in_fp16: flag to indicate if input in fp16 data format. input_in_fp16: flag to indicate if input in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking. attn_mask_type: attention mask type (pad or causal)
(used in gpt family networks) mask_func: mask function to be applied.
mask_func: mask function to be applied. softmax_in_fp32: if true, softmax is performed at fp32 precision.
softmax_in_fp32: if true, softmax is performed at fp32 precision. scale: scaling factor used in input tensor scaling.
scale: scaling factor used in input tensor scaling.
""" """
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale): def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__() super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion self.input_in_bf16 = input_in_bf16
self.general_mask_fusion = general_mask_fusion assert not (self.input_in_fp16 and self.input_in_bf16),\
'both fp16 and bf16 flags cannot be active at the same time.'
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32 self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale self.scale = scale
assert self.scale is None or softmax_in_fp32, \ assert (
'softmax should be in fp32 when scaled' self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask): def forward(self, input, mask):
# [b, np, s, s] # [b, np, sq, sk]
assert input.dim() == 4
data_size = input.size() data_size = input.size()
assert input.dim() == 4 query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
attn_batch_size = data_size[0] * data_size[1]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
# invoke custom kernel # invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \ if self.input_in_float16 and mask is not None and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \ custom_kernel_constraint and self.scaled_masked_softmax_fusion:
input.size()[2] == input.size()[3]: scale = self.scale if self.scale is not None else 1.0
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion: if self.attn_mask_type == AttnMaskType.causal:
input = input.view(-1, data_size[2], data_size[3]) assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size) probs = probs.view(*data_size)
else: else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale) probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else: else:
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_float16 and self.softmax_in_fp32:
input = input.float() input = input.float()
if self.scale is not None: if self.scale is not None:
input = input * self.scale input = input * self.scale
mask_output = self.mask_func(input, mask) mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output) probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_float16 and self.softmax_in_fp32:
probs = probs.half() if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs return probs
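# --- Illustrative aside (not part of this diff) ---------------------------
# Sketch of the non-fused fallback path in FusedScaleMaskSoftmax.forward:
# optionally upcast to fp32, scale, apply a masked_fill-style mask function,
# softmax, then cast back. `mask_func` below is an assumed stand-in for the
# mask function passed in by the caller.
import torch

def mask_func(scores, mask):
    return scores.masked_fill(mask, -10000.0)

b, heads, sq, sk = 2, 4, 8, 8
scores = torch.randn(b, heads, sq, sk).half()
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)   # nothing masked here

x = scores.float()                  # softmax_in_fp32
x = x * (1.0 / 8.0)                 # example scale factor
probs = torch.nn.Softmax(dim=-1)(mask_func(x, mask))
probs = probs.half()                # back to the input dtype
# --------------------------------------------------------------------------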
@@ -21,17 +21,13 @@ from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits from .language_model import parallel_lm_logits
from .language_model import get_language_model from .language_model import get_language_model
from .utils import init_method_normal from .utils import init_method_normal
from .utils import scaled_init_method_normal from .utils import scaled_init_method_normal
def gpt_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
def post_language_model_processing(lm_output, labels, logit_weights, def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output, get_key_value, parallel_output,
forward_method_parallel_output, forward_method_parallel_output,
@@ -61,40 +57,50 @@ def post_language_model_processing(lm_output, labels, logit_weights,
return loss return loss
class GPTModelBase(MegatronModule): class GPTModel(MegatronModule):
"""GPT-2 Language model.""" """GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self,
super(GPTModelBase, self).__init__() num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True):
super(GPTModel, self).__init__()
args = get_args() args = get_args()
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std), init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings(init_method_normal)
def forward(self, gpt_model_input, attention_mask, labels=None, def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value} lm_output = self.language_model(
if mpu.is_pipeline_first_stage(): input_ids,
(input_ids, position_ids) = gpt_model_input position_ids,
args = [input_ids, position_ids, attention_mask] attention_mask,
kwargs['tokentype_ids'] = tokentype_ids layer_past=layer_past,
else: get_key_value=get_key_value)
args = [gpt_model_input, attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage(): if self.post_process:
return post_language_model_processing( return post_language_model_processing(
lm_output, labels, lm_output, labels,
self.word_embeddings_weight(), self.word_embeddings_weight(),
@@ -113,7 +119,7 @@ class GPTModelBase(MegatronModule):
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
# Save word_embeddings. # Save word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \ state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars) = self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
@@ -122,79 +128,9 @@ class GPTModelBase(MegatronModule):
"""Customized load.""" """Customized load."""
# Load word_embeddings. # Load word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict( self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict) state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict: if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key] state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict) self.language_model.load_state_dict(state_dict, strict=strict)
class GPTModel(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPTModel, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPTModel, self).forward(
(input_ids, position_ids),
attention_mask,
labels=labels,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
class GPTModelFirstStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPTModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(GPTModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
class GPTModelIntermediateStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPTModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask,
layer_past=None, get_key_value=False):
return super(GPTModelIntermediateStage, self).forward(
hidden_state,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
class GPTModelLastStage(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPTModelLastStage, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask, labels=None,
layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPTModelLastStage, self).forward(
hidden_state,
attention_mask,
labels=labels,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
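# --- Illustrative aside (not part of this diff) ---------------------------
# What the removed gpt_attention_mask_func above was doing: fill future
# (upper-triangular) score positions with a large negative number so that
# softmax assigns them ~zero probability. Shapes are illustrative.
import torch

seq = 5
scores = torch.zeros(1, 1, seq, seq)
ltor_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores.masked_fill_(ltor_mask, -10000.0)
probs = torch.softmax(scores, dim=-1)    # each row attends only to the past
# --------------------------------------------------------------------------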
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -42,8 +43,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return mpu.gather_from_tensor_model_parallel_region(logits_parallel) return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None): encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
args = get_args() args = get_args()
@@ -51,27 +55,21 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None: if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model. # Language model.
args = [attention_mask_func, init_method, scaled_init_method] language_model = TransformerLanguageModel(
kwargs = {} init_method,
cls = None scaled_init_method,
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage(): encoder_attn_mask_type,
cls = TransformerLanguageModel num_tokentypes=num_tokentypes,
kwargs['num_tokentypes'] = num_tokentypes add_decoder=add_decoder,
kwargs['add_pooler'] = add_pooler decoder_attn_mask_type=decoder_attn_mask_type,
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage(): add_pooler=add_pooler,
cls = TransformerLanguageModelFirstStage pre_process=pre_process,
kwargs['num_tokentypes'] = num_tokentypes post_process=post_process
elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage(): )
cls = TransformerLanguageModelLastStage
kwargs['add_pooler'] = add_pooler
else:
cls = TransformerLanguageModelIntermediateStage
# Language model.
language_model = cls(*args, **kwargs)
# key used for checkpoints. # key used for checkpoints.
language_model_key = 'language_model' language_model_key = 'language_model'
@@ -257,17 +255,11 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True) 'checkpoint but could not find it', flush=True)
class TransformerLanguageModelBase(MegatronModule): class TransformerLanguageModel(MegatronModule):
"""Transformer language model. """Transformer language model.
Arguments: Arguments:
transformer_hparams: transformer hyperparameters transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
vocab_size: vocabulary size vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This max_sequence_length: maximum size of sequence. This
is used for positional embedding is used for positional embedding
@@ -277,21 +269,30 @@ class TransformerLanguageModelBase(MegatronModule):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
add_pooler=False): add_decoder=False,
super(TransformerLanguageModelBase, self).__init__() decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True):
super(TransformerLanguageModel, self).__init__()
args = get_args() args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler self.add_pooler = add_pooler
# Embeddings. # Embeddings.
if mpu.is_pipeline_first_stage(): if self.pre_process:
self.embedding = Embedding(self.hidden_size, self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size, args.padded_vocab_size,
args.max_position_embeddings, args.max_position_embeddings,
@@ -301,57 +302,109 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding' self._embedding_key = 'embedding'
# Transformer. # Transformer.
self.transformer = ParallelTransformer( self.encoder = ParallelTransformer(
attention_mask_func, self.init_method, self.init_method,
output_layer_init_method) output_layer_init_method,
self._transformer_key = 'transformer' self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
# Pooler. post_process=self.post_process
if mpu.is_pipeline_last_stage() and self.add_pooler: )
self.pooler = Pooler(self.hidden_size, self.init_method) self._encoder_key = 'encoder'
self._pooler_key = 'pooler'
# Decoder
def forward(self, language_model_input, attention_mask, if self.add_decoder:
tokentype_ids=None, layer_past=None, get_key_value=False, assert args.pipeline_model_parallel_size == 1, \
pooling_sequence_index=0): 'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings. # Embeddings.
if mpu.is_pipeline_first_stage(): if self.pre_process:
(input_ids, position_ids) = language_model_input embedding_output = self.embedding(enc_input_ids, enc_position_ids,
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
transformer_input = embedding_output encoder_input = embedding_output
else: else:
transformer_input = language_model_input encoder_input = None
# Transformer. # encoder.
transformer_output = self.transformer(transformer_input, if enc_hidden_states is None:
attention_mask, encoder_output = self.encoder(encoder_input,
layer_past=layer_past, enc_attn_mask,
get_key_value=get_key_value) layer_past=layer_past,
get_key_value=get_key_value)
if mpu.is_pipeline_last_stage() and self.add_pooler: else:
pooled_output = self.pooler(transformer_output, encoder_output = enc_hidden_states.to(encoder_input.dtype)
pooling_sequence_index)
return transformer_output, pooled_output if self.post_process:
if self.add_pooler:
return transformer_output pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
if mpu.is_pipeline_first_stage(): if self.pre_process:
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._transformer_key] \ state_dict_[self._encoder_key] \
= self.transformer.state_dict_for_save_checkpoint( = self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler: if self.post_process:
state_dict_[self._pooler_key] \ if self.add_pooler:
= self.pooler.state_dict_for_save_checkpoint( state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
@@ -360,7 +413,7 @@ class TransformerLanguageModelBase(MegatronModule):
"""Customized load.""" """Customized load."""
# Embedding. # Embedding.
if mpu.is_pipeline_first_stage(): if self.pre_process:
if self._embedding_key in state_dict: if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key] state_dict_ = state_dict[self._embedding_key]
else: else:
@@ -371,130 +424,41 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key] state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict) self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer. # Encoder.
if self._transformer_key in state_dict: if self._encoder_key in state_dict:
state_dict_ = state_dict[self._transformer_key] state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else: else:
# for backward compatibility. # for backward compatibility.
state_dict_ = {} state_dict_ = {}
for key in state_dict.keys(): for key in state_dict.keys():
if 'transformer.' in key: if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key] state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler. # for backward compatibility.
if mpu.is_pipeline_last_stage() and self.add_pooler: state_dict_self_attention = {}
assert 'pooler' in state_dict, \ for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if self.post_process:
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for the decoder in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict) strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
"""Transformer language model (see TransformerLanguageModelBase
for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""Transformer language model, first stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(TransformerLanguageModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""Transformer language model, intermediate stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
return super(TransformerLanguageModelIntermediateStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""Transformer language model, final stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModelLastStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
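# --- Illustrative aside (not part of this diff) ---------------------------
# Standalone sketch of the backward-compatibility key remapping performed in
# load_state_dict above: older checkpoints use '.attention.' where the new
# module layout expects '.self_attention.'. The key names below are made up.
old_state = {'layers.0.attention.query_key_value.weight': 'w0',
             'layers.0.mlp.dense_h_to_4h.weight': 'w1'}
new_state = {key.replace('.attention.', '.self_attention.'): value
             for key, value in old_state.items()}
assert 'layers.0.self_attention.query_key_value.weight' in new_state
# --------------------------------------------------------------------------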
@@ -25,6 +25,13 @@ from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module): class MegatronModule(torch.nn.Module):
...@@ -44,9 +51,9 @@ class MegatronModule(torch.nn.Module): ...@@ -44,9 +51,9 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self): def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage(ignore_virtual=True):
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage(ignore_virtual=True):
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last ' raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false') 'stage, but share_word_embeddings is false')
...@@ -60,6 +67,13 @@ class MegatronModule(torch.nn.Module): ...@@ -60,6 +67,13 @@ class MegatronModule(torch.nn.Module):
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism, there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
...@@ -73,22 +87,28 @@ class MegatronModule(torch.nn.Module): ...@@ -73,22 +87,28 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage(): assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings # set word_embeddings weights to 0 here, then copy first
# weights to 0 here, then copy first stage's weights using # stage's weights using all_reduce below.
# all_reduce below. self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = mpu.VocabParallelEmbedding( args.padded_vocab_size, args.hidden_size,
args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std)) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if torch.distributed.is_initialized():
torch.distributed.all_reduce(self.word_embeddings_weight().data, if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
group=mpu.get_embedding_group()) torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def conversion_helper(val, conversion): def conversion_helper(val, conversion):
...@@ -102,44 +122,56 @@ def conversion_helper(val, conversion): ...@@ -102,44 +122,56 @@ def conversion_helper(val, conversion):
return rtn return rtn
def fp32_to_fp16(val): def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16""" """Convert fp32 `val` to fp16/bf16"""
def half_conversion(val): def half_conversion(val):
val_typecheck = val val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)): if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES): if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half() val = float16_convertor(val)
return val return val
return conversion_helper(val, half_conversion) return conversion_helper(val, half_conversion)
def fp16_to_fp32(val): def float16_to_fp32(val):
"""Convert fp16 `val` to fp32""" """Convert fp16/bf16 `val` to fp32"""
def float_conversion(val): def float_conversion(val):
val_typecheck = val val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)): if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES): if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float() val = val.float()
return val return val
return conversion_helper(val, float_conversion) return conversion_helper(val, float_conversion)
class FP16Module(MegatronModule): class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('should not be here')
def __init__(self, module): self.float16_convertor = float16_convertor
super(FP16Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs) inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs) outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs) outputs = float16_to_fp32(outputs)
return outputs return outputs
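
The Float16Module wrapper casts the wrapped module to fp16 or bf16 and converts activations only at the pipeline boundaries: inputs are lowered on the first stage and outputs raised back to fp32 on the last. A minimal standalone sketch of that boundary conversion, with a simplified recursion standing in for conversion_helper:

import torch

# Sketch of the fp32 <-> fp16/bf16 boundary conversion idea: recurse over
# (possibly nested) inputs and convert only floating-point tensors.
def convert_floats(val, convertor):
    if isinstance(val, (tuple, list)):
        return type(val)(convert_floats(v, convertor) for v in val)
    if torch.is_tensor(val) and val.is_floating_point():
        return convertor(val)
    return val

inputs = (torch.randn(2, 3), torch.arange(3), [torch.ones(2)])
as_bf16 = convert_floats(inputs, lambda t: t.bfloat16())     # entering the module
back_to_fp32 = convert_floats(as_bf16, lambda t: t.float())  # leaving the module
assert back_to_fp32[1].dtype == torch.int64                  # integer tensors pass through
assert back_to_fp32[2][0].dtype == torch.float32
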
......
...@@ -19,7 +19,8 @@ import torch ...@@ -19,7 +19,8 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
...@@ -27,29 +28,40 @@ from megatron.model.utils import scaled_init_method_normal ...@@ -27,29 +28,40 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
class MultipleChoiceBase(MegatronModule): class MultipleChoice(MegatronModule):
def __init__(self, num_tokentypes=2): def __init__(self,
super(MultipleChoiceBase, self).__init__(share_word_embeddings=False) num_tokentypes=2,
pre_process=True,
post_process=True):
super(MultipleChoice, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.pre_process = pre_process
self.post_process = post_process
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head. # Multi-choice head.
if mpu.is_pipeline_last_stage(): if self.post_process:
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1, self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method) init_method)
self._multichoice_head_key = 'multichoice_head' self._multichoice_head_key = 'multichoice_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None): def forward(self, model_input, attention_mask, tokentype_ids=None):
# [batch, choices, sequence] --> [batch * choices, sequence] --> # [batch, choices, sequence] --> [batch * choices, sequence] -->
...@@ -63,22 +75,21 @@ class MultipleChoiceBase(MegatronModule): ...@@ -63,22 +75,21 @@ class MultipleChoiceBase(MegatronModule):
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) attention_mask = attention_mask.view(-1, attention_mask.size(-1))
extended_attention_mask = bert_extended_attention_mask(attention_mask) extended_attention_mask = bert_extended_attention_mask(attention_mask)
kwargs = {} input_ids = model_input
if mpu.is_pipeline_first_stage(): # Do the same as attention_mask for input_ids, tokentype_ids
input_ids = model_input assert len(input_ids.shape) == 3
# Do the same as attention_mask for input_ids, tokentype_ids assert len(tokentype_ids.shape) == 3
assert len(input_ids.shape) == 3 input_ids = input_ids.view(-1, input_ids.size(-1))
assert len(tokentype_ids.shape) == 3 tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
input_ids = input_ids.view(-1, input_ids.size(-1)) position_ids = bert_position_ids(input_ids)
tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
lm_output = self.language_model(
position_ids = bert_position_ids(input_ids) input_ids,
args = [input_ids, position_ids, extended_attention_mask] position_ids,
kwargs['tokentype_ids'] = tokentype_ids extended_attention_mask,
else: tokentype_ids=tokentype_ids
args = [model_input, extended_attention_mask] )
lm_output = self.language_model(*args, **kwargs) if self.post_process:
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output _, pooled_output = lm_output
multichoice_output = self.multichoice_dropout(pooled_output) multichoice_output = self.multichoice_dropout(pooled_output)
multichoice_logits = self.multichoice_head(multichoice_output) multichoice_logits = self.multichoice_head(multichoice_output)
...@@ -98,7 +109,7 @@ class MultipleChoiceBase(MegatronModule): ...@@ -98,7 +109,7 @@ class MultipleChoiceBase(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
state_dict_[self._multichoice_head_key] \ state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict( = self.multichoice_head.state_dict(
destination, prefix, keep_vars) destination, prefix, keep_vars)
...@@ -109,7 +120,7 @@ class MultipleChoiceBase(MegatronModule): ...@@ -109,7 +120,7 @@ class MultipleChoiceBase(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self._multichoice_head_key in state_dict: if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict( self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict) state_dict[self._multichoice_head_key], strict=strict)
...@@ -117,54 +128,3 @@ class MultipleChoiceBase(MegatronModule): ...@@ -117,54 +128,3 @@ class MultipleChoiceBase(MegatronModule):
print_rank_last('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format( 'initializing to random'.format(
self._multichoice_head_key)) self._multichoice_head_key))
class MultipleChoice(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoice, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceFirstStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoiceFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceIntermediateStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceIntermediateStage, self).forward(
hidden_state,
attention_mask)
class MultipleChoiceLastStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceLastStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceLastStage, self).forward(
hidden_state,
attention_mask)
...@@ -6,11 +6,12 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi ...@@ -6,11 +6,12 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel from megatron.model import BertModel
from .module import MegatronModule from .module import MegatronModule
from megatron import mpu from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False): def general_ict_model_provider(only_query_model=False, only_block_model=False):
...@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule): ...@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
args.num_layers) args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Transformer.""" """Transformer."""
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -22,11 +21,11 @@ import torch.nn.functional as F ...@@ -22,11 +21,11 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
...@@ -47,12 +46,6 @@ torch._C._jit_override_can_fuse_on_gpu(True) ...@@ -47,12 +46,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
Transformer takes input of size [s, b, h] and returns a Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments: tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmaksed-attention-scores, attention-mask)
""" """
class ParallelMLP(MegatronModule): class ParallelMLP(MegatronModule):
...@@ -71,7 +64,7 @@ class ParallelMLP(MegatronModule): ...@@ -71,7 +64,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h. # Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear( self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
4 * args.hidden_size, args.ffn_hidden_size,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
skip_bias_add=True) skip_bias_add=True)
...@@ -85,12 +78,12 @@ class ParallelMLP(MegatronModule): ...@@ -85,12 +78,12 @@ class ParallelMLP(MegatronModule):
# Project back to h. # Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear( self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size, args.ffn_hidden_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -109,41 +102,61 @@ class ParallelMLP(MegatronModule): ...@@ -109,41 +102,61 @@ class ParallelMLP(MegatronModule):
return output, output_bias return output, output_bias
class ParallelSelfAttention(MegatronModule): class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class. """Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h] Self-attention layer takes input with size [b, s, h]
and returns output of the same size. and returns output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, init_method,
output_layer_init_method, layer_number): output_layer_init_method, layer_number,
super(ParallelSelfAttention, self).__init__() attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args() args = get_args()
self.fp16 = args.fp16 self.fp16 = args.fp16
self.bf16 = args.bf16
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling: if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number) self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size() world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size, self.hidden_size_per_partition = mpu.divide(projection_size,
world_size) world_size)
self.hidden_size_per_attention_head = mpu.divide( self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads) projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide( self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size) args.num_attention_heads, world_size)
# Strided linear layer. # Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear( if attention_type == AttnType.self_attn:
args.hidden_size, self.query_key_value = mpu.ColumnParallelLinear(
3 * args.hidden_size, args.hidden_size,
gather_output=False, 3 * projection_size,
init_method=init_method) gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
...@@ -152,10 +165,10 @@ class ParallelSelfAttention(MegatronModule): ...@@ -152,10 +165,10 @@ class ParallelSelfAttention(MegatronModule):
self.norm_factor *= coeff self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax( self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.fp16, self.bf16,
args.scaled_upper_triang_masked_softmax_fusion, self.attn_mask_type,
args.scaled_masked_softmax_fusion, args.masked_softmax_fusion,
self.attention_mask_func, attention_mask_func,
self.attention_softmax_in_fp32, self.attention_softmax_in_fp32,
coeff) coeff)
...@@ -166,72 +179,55 @@ class ParallelSelfAttention(MegatronModule): ...@@ -166,72 +179,55 @@ class ParallelSelfAttention(MegatronModule):
# Output. # Output.
self.dense = mpu.RowParallelLinear( self.dense = mpu.RowParallelLinear(
args.hidden_size, projection_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): def forward(self, hidden_states, attention_mask, layer_past=None,
input_shape = mixed_layer.size(); get_key_value=False, encoder_output=None):
if num_splits_first: # hidden_states: [sq, b, h]
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\ # =====================
(num_splits, self.num_attention_heads_per_partition, # Query, Key, and Value
self.hidden_size_per_attention_head) # =====================
mixed_layer = mixed_layer.view(*intermediate_shape) if self.attention_type == AttnType.self_attn:
mixed_layer = mixed_layer.transpose(-2, -3).contiguous() # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
else: mixed_x_layer, _ = self.query_key_value(hidden_states)
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, num_splits) 3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
mixed_layer = mixed_layer.view(*intermediate_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
mixed_layer = mixed_layer.transpose(-1, -2).contiguous() (query_layer,
mixed_layer = mixed_layer.view(*input_shape) key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
return mixed_layer else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
def forward(self, hidden_states, attention_mask, layer_past=None, # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
get_key_value=False): new_tensor_shape = mixed_kv_layer.size()[:-1] + \
# hidden_states: [sq, b, h] (self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# ===================== # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
# Query, Key, and Value (key_layer,
# ===================== value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] # Attention head [sq, b, h] --> [sq, b, hp]
mixed_x_layer, _ = self.query_key_value(hidden_states) query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
checkpoint_version = get_checkpoint_version() new_tensor_shape = query_layer.size()[:-1] + \
if checkpoint_version is not None: (self.num_attention_heads_per_partition,
if checkpoint_version == 0: self.hidden_size_per_attention_head)
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] query_layer = query_layer.view(*new_tensor_shape)
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
...@@ -246,41 +242,41 @@ class ParallelSelfAttention(MegatronModule): ...@@ -246,41 +242,41 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value: if get_key_value:
present = (key_layer, value_layer) present = (key_layer, value_layer)
# =================================== # ===================================
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
# =================================== # ===================================
# [b, np, sq, sk] # [b, np, sq, sk]
output_size = (query_layer.size(1), output_size = (query_layer.size(1),
query_layer.size(2), query_layer.size(2),
query_layer.size(0), query_layer.size(0),
key_layer.size(0)) key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk] # preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty( matmul_result = torch.empty(
output_size[0]*output_size[1], output_size[0]*output_size[1],
output_size[2], output_size[2],
output_size[3], output_size[3],
dtype=query_layer.dtype, dtype=query_layer.dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk] # Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result, matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn] query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor)) beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
# ================================================== # ==================================================
# Update attention mask for inference. [b, np, sq, sk] # Update attention mask for inference. [b, np, sq, sk]
# ================================================== # ==================================================
...@@ -298,7 +294,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -298,7 +294,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3), :attention_scores.size(3),
:attention_scores.size(3)] :attention_scores.size(3)]
# =========================== # ===========================
# Attention probs and dropout # Attention probs and dropout
# =========================== # ===========================
...@@ -312,7 +307,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -312,7 +307,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
# ========================= # =========================
# Context layer. [sq, b, hp] # Context layer. [sq, b, hp]
# ========================= # =========================
...@@ -321,21 +315,21 @@ class ParallelSelfAttention(MegatronModule): ...@@ -321,21 +315,21 @@ class ParallelSelfAttention(MegatronModule):
# [sk, b, np, hn] --> [b, np, sq, hn] # [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn] # context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), output_size = (value_layer.size(1),
value_layer.size(2), value_layer.size(2),
query_layer.size(0), query_layer.size(0),
value_layer.size(3)) value_layer.size(3))
# change view [sk, b * np, hn] # change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk] # change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1) output_size[2], -1)
# matmul: [b * np, sq, hn] # matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1)) context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn] # change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size) context_layer = context_layer.view(*output_size)
...@@ -348,7 +342,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -348,7 +342,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,) (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
# ================= # =================
# Output. [sq, b, h] # Output. [sq, b, h]
# ================= # =================
...@@ -361,7 +354,7 @@ class ParallelSelfAttention(MegatronModule): ...@@ -361,7 +354,7 @@ class ParallelSelfAttention(MegatronModule):
return output, bias return output, bias
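
In the self-attention branch above, the fused query_key_value projection is reshaped and split into per-head query, key, and value tensors before the baddbmm score computation. A small sketch of that reshape/split, with torch.split standing in for mpu.split_tensor_along_last_dim:

import torch

# One projection of width np*3*hn is viewed as [sq, b, np, 3*hn] and split into
# query/key/value of shape [sq, b, np, hn] each (np = heads, hn = head dim).
sq, b, np_, hn = 5, 2, 4, 8
mixed_x_layer = torch.randn(sq, b, np_ * 3 * hn)
mixed_x_layer = mixed_x_layer.view(sq, b, np_, 3 * hn)
query, key, value = torch.split(mixed_x_layer, hn, dim=-1)
assert query.shape == key.shape == value.shape == (sq, b, np_, hn)
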
def bias_dropout_add(x, bias, residual, prob, training) : def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training) out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out out = residual + out
...@@ -375,13 +368,13 @@ def get_bias_dropout_add(training): ...@@ -375,13 +368,13 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) : def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor # type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) : def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor # type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
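
The two @torch.jit.script variants bake the training flag in as a constant, matching the comment further down that scripting an nn.Module with dropout does not trigger the fusion kernel. At inference the fused call collapses to a bias add plus the residual add, which a quick standalone check confirms:

import torch

# With training=False the dropout is a no-op, so the fused inference path
# reduces to bias-add followed by the residual add.
x, bias, residual = torch.randn(4, 8), torch.randn(8), torch.randn(4, 8)
fused = torch.nn.functional.dropout(x + bias, p=0.1, training=False) + residual
manual = (x + bias) + residual
assert torch.allclose(fused, manual)
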
...@@ -389,66 +382,85 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) : ...@@ -389,66 +382,85 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule): class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer. """A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an Transformer layer takes input with size [b, s, h] and returns an
output of the same size. output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, init_method, output_layer_init_method,
output_layer_init_method, layer_number): layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args() args = get_args()
super(ParallelTransformerLayer, self).__init__() super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \ self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data. # Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
# Self attention. # Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method, self.self_attention = ParallelAttention(
output_layer_init_method, init_method,
layer_number) output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data. # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP # MLP
self.mlp = ParallelMLP(init_method, self.mlp = ParallelMLP(init_method,
output_layer_init_method) output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask,
get_key_value=False): encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.attention(layernorm_output, self.self_attention(layernorm_output,
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value) get_key_value=get_key_value)
if get_key_value: if get_key_value:
attention_output, presents = attention_output attention_output, presents = attention_output
# Residual connection. # Residual connection.
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = layernorm_output residual = layernorm_output
else: else:
residual = hidden_states residual = hidden_states
# jit scripting for a nn.module (with dropout) is not # jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two # triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying # different nn.functional routines to account for varying
# dropout semantics during training and inference phases. # dropout semantics during training and inference phases.
if self.bias_dropout_fusion: if self.bias_dropout_fusion:
...@@ -459,7 +471,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -459,7 +471,7 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
bias_dropout_add_func = get_bias_dropout_add(self.training) bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
...@@ -470,16 +482,38 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -470,16 +482,38 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection. # Second residual connection.
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = layernorm_output residual = layernorm_output
else: else:
residual = layernorm_input residual = layernorm_input
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
output = bias_dropout_add_func( output = bias_dropout_add_func(
mlp_output, mlp_output,
...@@ -496,12 +530,18 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -496,12 +530,18 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule): class ParallelTransformer(MegatronModule):
"""Transformer class.""" """Transformer class."""
def __init__(self, attention_mask_func, def __init__(self, init_method, output_layer_init_method,
init_method, output_layer_init_method): layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpointing flag. # Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations self.checkpoint_activations = args.checkpoint_activations
...@@ -515,15 +555,38 @@ class ParallelTransformer(MegatronModule): ...@@ -515,15 +555,38 @@ class ParallelTransformer(MegatronModule):
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
return ParallelTransformerLayer( return ParallelTransformerLayer(
attention_mask_func, init_method, init_method,
output_layer_init_method, layer_number) output_layer_init_method,
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]) [build_layer(i + 1 + offset) for i in range(self.num_layers)])
if mpu.is_pipeline_last_stage(): if self.post_process:
# Final layer norm before output. # Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
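
The layer-to-stage tables in the comments above can be reproduced with the same offset arithmetic in plain Python (no mpu calls), which makes the interleaved assignment easy to verify:

# Sketch of the interleaved layer assignment documented in ParallelTransformer.__init__:
# 8 layers, 2 pipeline stages, virtual_pipeline_model_parallel_size = 2.
num_layers, pipeline_size, virtual_size = 8, 2, 2
layers_per_chunk = num_layers // pipeline_size // virtual_size   # layers per model chunk

for stage in range(pipeline_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        offset = virtual_rank * (num_layers // virtual_size) + stage * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print(f"Stage {stage}: {chunks}")
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]
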
...@@ -531,14 +594,18 @@ class ParallelTransformer(MegatronModule): ...@@ -531,14 +594,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number): def _get_layer(self, layer_number):
return self.layers[layer_number] return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask): def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing.""" """Forward method with activation checkpointing."""
def custom(start, end): def custom(start, end):
def custom_forward(*inputs): def custom_forward(*inputs):
x_ = inputs[0] x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end): for index in range(start, end):
layer = self._get_layer(index) layer = self._get_layer(index)
x_ = layer(x_, inputs[1]) x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_ return x_
return custom_forward return custom_forward
...@@ -548,13 +615,23 @@ class ParallelTransformer(MegatronModule): ...@@ -548,13 +615,23 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers: while l < self.num_layers:
hidden_states = mpu.checkpoint( hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers), custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask) hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers l += self.checkpoint_num_layers
return hidden_states return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism, the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func."""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False): get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks. # Checks.
if layer_past is not None: if layer_past is not None:
...@@ -566,7 +643,7 @@ class ParallelTransformer(MegatronModule): ...@@ -566,7 +643,7 @@ class ParallelTransformer(MegatronModule):
'get_key_value does not work with ' \ 'get_key_value does not work with ' \
'activation checkpointing' 'activation checkpointing'
if mpu.is_pipeline_first_stage(): if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h]. # Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float. # If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection: if self.fp32_residual_connection:
...@@ -574,10 +651,18 @@ class ParallelTransformer(MegatronModule): ...@@ -574,10 +651,18 @@ class ParallelTransformer(MegatronModule):
# Otherwise, leave it as is. # Otherwise, leave it as is.
else: else:
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations: if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
attention_mask) attention_mask,
encoder_output,
enc_dec_attn_mask)
else: else:
if get_key_value: if get_key_value:
presents = [] presents = []
...@@ -588,14 +673,16 @@ class ParallelTransformer(MegatronModule): ...@@ -588,14 +673,16 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index] past = layer_past[index]
hidden_states = layer(hidden_states, hidden_states = layer(hidden_states,
attention_mask, attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past, layer_past=past,
get_key_value=get_key_value) get_key_value=get_key_value)
if get_key_value: if get_key_value:
hidden_states, present = hidden_states hidden_states, present = hidden_states
presents.append(present) presents.append(present)
# Final layer norm. # Final layer norm.
if mpu.is_pipeline_last_stage(): if self.post_process:
# Reverting data format change [s b h] --> [b s h]. # Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
......
...@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers): ...@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_ return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
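
attention_mask_func fills masked positions with a large negative score so that the subsequent softmax assigns them essentially zero probability; a short standalone sketch:

import torch

# Positions where the mask is True receive a large negative score, so softmax
# drives their probability to (effectively) zero.
scores = torch.zeros(1, 1, 2, 2)                 # [b, np, sq, sk]
mask = torch.tensor([[[[False, True],
                       [False, False]]]])        # True = masked out
scores.masked_fill_(mask, -10000.0)
probs = torch.softmax(scores, dim=-1)
assert probs[0, 0, 0, 1] < 1e-4
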
def get_linear_layer(rows, columns, init_method): def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization.""" """Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns) layer = torch.nn.Linear(rows, columns)
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
hidden_size = args.hidden_size
key = prefix + "weight"
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
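
The hook's core step is a bilinear resize of the position-embedding grid (excluding the class token) when the checkpoint was trained with a different number of patches per side. A standalone sketch of that resize, here 7x7 to 14x14 with a toy hidden size:

import torch
import torch.nn.functional as F

# Reshape [tokens, h] into a [1, h, g, g] grid, bilinearly resize to the new
# grid, and flatten back to [new_tokens, h].
hidden_size, gs_input, gs_new = 8, 7, 14
grid = torch.randn(gs_input * gs_input, hidden_size)
grid = grid.transpose(0, 1).reshape(1, hidden_size, gs_input, gs_input)
grid = F.interpolate(grid, scale_factor=gs_new / gs_input, mode="bilinear")
grid = grid.reshape(hidden_size, gs_new * gs_new).transpose(0, 1)
assert grid.shape == (gs_new * gs_new, hidden_size)
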
class VitModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
return x
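
The einops.rearrange at the top of forward() turns an image into a sequence of flattened patches; a small sketch with toy sizes (assuming einops is available, as imported by this file):

import torch
import einops

# A [b, c, H, W] image with p x p patches becomes [b, (H/p)*(W/p), p*p*c] tokens.
x = torch.randn(2, 3, 4, 4)                       # batch of 4x4 RGB "images"
tokens = einops.rearrange(
    x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=2, p2=2)
assert tokens.shape == (2, 4, 12)                 # 4 patches, each 2*2*3 values
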
--- a/megatron/mpu/__init__.py
+++ b/megatron/mpu/__init__.py
@@ -38,13 +38,15 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
-from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes,
+from .layers import (set_tensor_model_parallel_attributes,
+                     set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)
 from .mappings import copy_to_tensor_model_parallel_region
@@ -57,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks
 from .utils import divide
 from .utils import split_tensor_along_last_dim
--- a/megatron/mpu/data.py
+++ b/megatron/mpu/data.py
@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_src_rank

-_MAX_DATA_DIM = 4
+_MAX_DATA_DIM = 5

 def _check_data_types(keys, data, target_dtype):
...
--- a/megatron/mpu/initialize.py
+++ b/megatron/mpu/initialize.py
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())

-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+           get_virtual_pipeline_model_parallel_rank() != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0

-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+           get_virtual_pipeline_model_parallel_rank() != (
+               virtual_pipeline_model_parallel_world_size - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)

+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size

 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]

 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]

 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
...
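As an aside, the interplay between the physical pipeline rank and the new virtual (interleaved) rank can be hard to see in diff form. The toy function below mirrors the `is_pipeline_last_stage` logic with plain module-level variables; it is an illustration only and does not touch `torch.distributed` or the real Megatron globals.

```python
# Stand-ins for the mpu globals; the values are arbitrary examples.
_PIPELINE_RANK, _PIPELINE_WORLD_SIZE = 3, 4   # last physical pipeline stage
_VIRTUAL_RANK, _VIRTUAL_WORLD_SIZE = 0, 2     # first of two interleaved model chunks

def is_pipeline_last_stage(ignore_virtual=False):
    # Mirrors the logic added in this hunk: with interleaving enabled,
    # only the final model chunk on the final physical stage counts as "last".
    if not ignore_virtual and _VIRTUAL_WORLD_SIZE is not None \
            and _VIRTUAL_RANK != _VIRTUAL_WORLD_SIZE - 1:
        return False
    return _PIPELINE_RANK == _PIPELINE_WORLD_SIZE - 1

print(is_pipeline_last_stage())                      # False: chunk 0 of 2
print(is_pipeline_last_stage(ignore_virtual=True))   # True: last physical stage
```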
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -43,6 +43,12 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_stride': 1}

+def param_is_not_tensor_parallel_duplicate(param):
+    return (hasattr(param, 'tensor_model_parallel') and
+            param.tensor_model_parallel) or (
+                get_tensor_model_parallel_rank() == 0)
+
 def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
     # Make sure the attributes are not set.
     for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
@@ -260,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.output_size_per_partition,
                 device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
-            self.bias.tensor_model_parallel = True
-            self.bias.partition_dim = 0
-            self.bias.stride = stride
+            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
...
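The new `param_is_not_tensor_parallel_duplicate` helper exists so that tensors replicated across tensor-parallel ranks (for example LayerNorm weights) are counted only once in global reductions such as the grad norm or zero count. A rough single-process sketch of the rule, with a dummy parameter object and the rank passed in explicitly rather than queried from `torch.distributed`, is shown below.

```python
class FakeParam:
    """Dummy stand-in for a torch parameter carrying the mpu attribute."""
    def __init__(self, tensor_model_parallel=None):
        if tensor_model_parallel is not None:
            self.tensor_model_parallel = tensor_model_parallel

def param_is_not_tensor_parallel_duplicate(param, tp_rank):
    # Same rule as in the diff, minus the distributed rank query.
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or tp_rank == 0

sharded = FakeParam(tensor_model_parallel=True)       # e.g. a ColumnParallelLinear weight
replicated = FakeParam(tensor_model_parallel=False)   # e.g. a LayerNorm weight

print(param_is_not_tensor_parallel_duplicate(sharded, tp_rank=1))     # True: every shard is unique
print(param_is_not_tensor_parallel_duplicate(replicated, tp_rank=1))  # False: counted on rank 0 only
```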
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -14,35 +14,35 @@
 # limitations under the License.

 from apex.optimizers import FusedAdam as Adam
+from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
+from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)

     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
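A quick usage sketch for the reworked `_get_params_for_weight_decay_optimization`: it now expects a list of modules (one per virtual-pipeline model chunk) and returns two optimizer param groups, one with the default weight decay and one with `weight_decay=0.0` for LayerNorm weights and biases. The equivalent grouping can be reproduced with vanilla PyTorch modules as follows (illustrative only; it does not import Megatron).

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in model.modules():
    if isinstance(module, torch.nn.LayerNorm):
        # All LayerNorm parameters are excluded from weight decay.
        no_weight_decay_params['params'].extend(
            p for p in module._parameters.values() if p is not None)
    else:
        # Weights decay; biases do not.
        weight_decay_params['params'].extend(
            p for n, p in module._parameters.items()
            if p is not None and n != 'bias')
        no_weight_decay_params['params'].extend(
            p for n, p in module._parameters.items()
            if p is not None and n == 'bias')

# The two groups are then handed straight to the base optimizer.
optimizer = torch.optim.AdamW([weight_decay_params, no_weight_decay_params],
                              lr=1e-4, weight_decay=0.01)
```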
@@ -52,28 +52,58 @@ def get_megatron_optimizer(model):

     # Base optimizer.
     param_groups = _get_params_for_weight_decay_optimization(model)
-    optimizer = Adam(param_groups,
-                     lr=args.lr,
-                     weight_decay=args.weight_decay,
-                     betas=(args.adam_beta1, args.adam_beta2),
-                     eps=args.adam_eps)
+    if args.optimizer == 'adam':
+        optimizer = Adam(param_groups,
+                         lr=args.lr,
+                         weight_decay=args.weight_decay,
+                         betas=(args.adam_beta1, args.adam_beta2),
+                         eps=args.adam_eps)
+    elif args.optimizer == 'sgd':
+        optimizer = SGD(param_groups,
+                        lr=args.lr,
+                        weight_decay=args.weight_decay,
+                        momentum=args.sgd_momentum)
+    else:
+        raise Exception('{} optimizer is not supported.'.format(
+            args.optimizer))
+
+    # Determine whether the params have main-grad field.
+    params_have_main_grad = False
+    if args.DDP_impl == 'local':
+        params_have_main_grad = True

-    if args.fp16:
+    if args.fp16 or args.bf16:

+        # Grad scaler:
+        #    if loss-scale is provided, instantiate the constant scaler.
+        #    if we are using fp16 and loss-scale is not present, use a
+        #       dynamic scaler.
+        #    otherwise we are running in bf16 with no loss-scale so
+        #       leave it as None.
+        grad_scaler = None
         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
         # Dynamic loss scale.
         else:
-            grad_scaler = DynamicGradScaler(
-                initial_scale=args.initial_loss_scale,
-                min_scale=args.min_loss_scale,
-                growth_factor=2.0,
-                backoff_factor=0.5,
-                growth_interval=args.loss_scale_window,
-                hysteresis=args.hysteresis)
+            if args.fp16:
+                grad_scaler = DynamicGradScaler(
+                    initial_scale=args.initial_loss_scale,
+                    min_scale=args.min_loss_scale,
+                    growth_factor=2.0,
+                    backoff_factor=0.5,
+                    growth_interval=args.loss_scale_window,
+                    hysteresis=args.hysteresis)

         # Megatron optimizer.
-        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
-                                           args.clip_grad)
+        return Float16OptimizerWithFloat16Params(optimizer,
+                                                 args.clip_grad,
+                                                 args.log_num_zeros_in_grad,
+                                                 params_have_main_grad,
+                                                 args.bf16,
+                                                 grad_scaler)

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad,
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad)
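The branching above decides which gradient scaler the mixed-precision optimizer gets. As a compact summary, a hedged sketch of the same decision, using a hypothetical `args` namespace rather than Megatron's real `get_args()`, looks like this:

```python
from types import SimpleNamespace

def pick_grad_scaler(args):
    """Return 'constant', 'dynamic', or None, following the logic above."""
    if not (args.fp16 or args.bf16):
        return None              # pure fp32: no scaler, FP32Optimizer path
    if args.loss_scale:
        return 'constant'        # user-fixed loss scale
    if args.fp16:
        return 'dynamic'         # fp16 without a fixed scale
    return None                  # bf16 without a loss scale runs unscaled

print(pick_grad_scaler(SimpleNamespace(fp16=True,  bf16=False, loss_scale=None)))  # dynamic
print(pick_grad_scaler(SimpleNamespace(fp16=False, bf16=True,  loss_scale=None)))  # None
print(pick_grad_scaler(SimpleNamespace(fp16=False, bf16=True,  loss_scale=1024)))  # constant
```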
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -22,6 +22,8 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron import mpu
+from megatron.model.module import param_is_not_shared
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
@@ -54,9 +56,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     grads_for_norm = []
     for param in parameters:
         grad_not_none = param.grad is not None
-        is_not_shared = not hasattr(param, 'shared') or not param.shared
-        is_not_tp_duplicate = param.tensor_model_parallel or \
-            (mpu.get_tensor_model_parallel_rank() == 0)
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
         grad = param.grad.detach()
         if grad_not_none:
             # Make sure the grads are in fp32
@@ -117,3 +118,31 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                          clip_coeff)

     return total_norm
+
+
+def count_zeros_fp32(parameters):
+
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    total_num_zeros = 0.0
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            grad = param.grad.detach()
+            num_zeros = grad.numel() - torch.count_nonzero(grad)
+            total_num_zeros = num_zeros + total_num_zeros
+
+    # Sum across all model-parallel GPUs.
+    torch.distributed.all_reduce(total_num_zeros,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=mpu.get_model_parallel_group())
+    total_num_zeros = total_num_zeros.item()
+
+    return total_num_zeros
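`count_zeros_fp32` reports how many gradient entries are exactly zero, skipping shared parameters and tensor-parallel replicas, and then all-reduces the count across the model-parallel group. The per-tensor arithmetic is just the following (single-process sketch, no distributed reduction):

```python
import torch

grad = torch.tensor([0.0, 1.5, 0.0, -2.0, 0.0])
# Zeros = total elements minus nonzero elements.
num_zeros = grad.numel() - torch.count_nonzero(grad)
print(num_zeros.item())  # 3
```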
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -27,7 +27,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0

-from .clip_grads import clip_grad_norm_fp32
+from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32


 def _zero_grad_group_helper(group, set_to_none):
@@ -46,49 +46,77 @@ def _zero_grad_group_helper(group, set_to_none):
 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another."""
+    """Use multi-tensor-applier to copy values from one list to another.
+    We don't have a bfloat16 implementation so for now if the overflow_buf
+    is not provided, we default back to simple loop copy to be compatible
+    with bfloat16."""
     if overflow_buf:
         overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             overflow_buf,
+                             [this, that],
+                             1.0)
     else:
-        overflow_buf = torch.cuda.IntTensor([0])
-    # Scaling with factor `1.0` is equivalent to copy.
-    multi_tensor_applier(amp_C.multi_tensor_scale,
-                         overflow_buf,
-                         [this, that],
-                         1.0)
+        for this_, that_ in zip(this, that):
+            that_.copy_(this_)


 class MegatronOptimizer(ABC):

-    def __init__(self, optimizer):
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
+        # Set gradient clipping and logging params.
+        self.clip_grad = clip_grad
+        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        self.params_have_main_grad = params_have_main_grad

-    def clip_grad_norm(self, clip_grad):
+    def get_parameters(self):
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 params.append(param)
-        clip_grad_norm_fp32(params, clip_grad)
+        return params
+
+    def clip_grad_norm(self, clip_grad):
+        params = self.get_parameters()
+        return clip_grad_norm_fp32(params, clip_grad)
+
+    def count_zeros(self):
+        params = self.get_parameters()
+        return count_zeros_fp32(params)

     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass

     @abstractmethod
     def get_loss_scale(self):
         """The output should be a cuda tensor of size 1."""
         pass

     def scale_loss(self, loss):
         """Simple scaling."""
         return self.get_loss_scale() * loss

     @abstractmethod
     def step(self):
         pass

     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.
@@ -98,14 +126,17 @@ class MegatronOptimizer(ABC):
         with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
     def state_dict(self):
         pass

     @abstractmethod
     def load_state_dict(self, state_dict):
         pass

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
@@ -116,6 +147,7 @@ class MegatronOptimizer(ABC):
     state = property(_get_state, _set_state)

     # Promote param_groups so it can be retrieved or set via
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
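The refactored base class now owns `clip_grad`, `log_num_zeros_in_grad`, and `params_have_main_grad`, and exposes `get_parameters()` so that gradient clipping and zero counting walk the same flattened parameter list. The flattening itself is plain PyTorch; a minimal equivalent with a throwaway model is:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same traversal as MegatronOptimizer.get_parameters(): flatten all groups.
params = [p for group in optimizer.param_groups for p in group['params']]
assert len(params) == 2  # weight and bias
```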
@@ -129,49 +161,90 @@ class MegatronOptimizer(ABC):
-class FP16OptimizerWithFP16Params(MegatronOptimizer):
+class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradients with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are stored in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a continuous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """

-    def __init__(self, optimizer, grad_scaler, clip_grad):
-        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, bf16, grad_scaler):
+        super(Float16OptimizerWithFloat16Params, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad)

+        self.bf16 = bf16
         self.grad_scaler = grad_scaler
-        self.clip_grad = clip_grad
+        # None grad scaler is only supported for bf16.
+        if self.grad_scaler is None:
+            assert self.bf16, 'fp16 expects a grad scaler.'

         # Tensor used to determine if a nan/inf has happened.
         # Any non-zero value indicates inf/nan.
-        self.found_inf = torch.cuda.FloatTensor([0.0])
+        # Note that we keep this for the cases that grad scaler is none.
+        # We still record nan/inf if we have a bfloat16 with a grad scaler.
+        if self.grad_scaler:
+            self.found_inf = torch.cuda.FloatTensor([0.0])

         # Dummy tensor needed for apex multi-apply tensor.
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # For bfloat, we don't have multi-tensor apply and for now
+        # we set it to none so the multi-tensor apply gets ignored.
+        if bf16:
+            self._dummy_overflow_buf = None
+        else:
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+        # In case grad scaler is not passed, define the unity scale.
+        if self.grad_scaler is None:
+            self._scale_one = torch.cuda.FloatTensor([1.0])

         # ======================
         # main parameter stuff
         # ======================

         # Three groups of parameters:
-        #   fp16_groups: original fp16 parameters
-        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
         #   fp32_from_fp32_groups: original fp32 parameters
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
+        self.float16_groups = []
+        self.fp32_from_float16_groups = []
         self.fp32_from_fp32_groups = []

         # For all the groups in the original optimizer:
         for param_group in self.optimizer.param_groups:
-            fp16_params_this_group = []
+            float16_params_this_group = []
             fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
+            fp32_from_float16_params_this_group = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:

-                    # fp16 params:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        fp16_params_this_group.append(param)
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor',
+                                        'torch.cuda.BFloat16Tensor']:
+                        float16_params_this_group.append(param)
                         # Create a copy
                         main_param = param.detach().clone().float()
-                        # Store grads
-                        main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
                         mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)
@@ -179,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
-                        fp32_from_fp16_params_this_group.append(main_param)
+                        fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \
@@ -191,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             param_group['params'][i] = param

                         else:
-                            raise TypeError("Wrapped parameters must be either "
-                                            "torch.cuda.FloatTensor or "
-                                            "torch.cuda.HalfTensor. "
-                                            "Received {}".format(param.type()))
+                            raise TypeError('Wrapped parameters must be one of '
+                                            'torch.cuda.FloatTensor, '
+                                            'torch.cuda.HalfTensor, or '
+                                            'torch.cuda.BFloat16Tensor. '
+                                            'Received {}'.format(param.type()))

-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+            self.float16_groups.append(float16_params_this_group)
+            self.fp32_from_float16_groups.append(
+                fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

         # Leverage state_dict() and load_state_dict() to
@@ -207,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
-        fp16_groups & fp32_from_fp32_groups."""
-        for group in self.fp16_groups:
+        float16_groups & fp32_from_fp32_groups."""
+        for group in self.float16_groups:
             _zero_grad_group_helper(group, set_to_none)
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)

     def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
         return self.grad_scaler.scale

     def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the fp16 group.
-        model_grads = []
-        main_grads = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
-                if model_param.grad is not None:
-                    if main_param.grad is None:
-                        main_param.grad = torch.empty_like(main_param)
-                    model_grads.append(model_param.grad.data)
-                    main_grads.append(main_param.grad.data)
-        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
-                                        overflow_buf=self._dummy_overflow_buf)
+                if self.params_have_main_grad:
+                    main_param.grad = model_param.main_grad.float()
+                else:
+                    if model_param.grad is not None:
+                        main_param.grad = model_param.grad.float()
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        if self.params_have_main_grad:
+            for model_group in self.fp32_from_fp32_groups:
+                for model_param in model_group:
+                    model_param.grad = model_param.main_grad

     def _unscale_main_grads_and_check_for_nan(self):
         main_grads = []
-        # fp32 params from fp16 ones.
-        for main_group in self.fp32_from_fp16_groups:
+        # fp32 params from float16 ones.
+        for main_group in self.fp32_from_float16_groups:
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)
@@ -261,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag

-    def _get_model_and_main_params_data_fp16(self):
+    def _get_model_and_main_params_data_float16(self):
         model_data = []
         main_data = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
                 main_data.append(main_param.data)
@@ -273,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def _copy_main_params_to_model_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)

     def _copy_model_params_to_main_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                         overflow_buf=self._dummy_overflow_buf)
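The float16 optimizer keeps three groups: the model's float16 params, fp32 master copies of those params, and any params that were fp32 to begin with. Gradients are moved onto the fp32 copies, the base optimizer updates those copies, and the results are copied back into the float16 weights. A stripped-down version of that round trip for a single tensor is sketched below (CPU tensors, no apex multi-tensor kernel; the values are illustrative).

```python
import torch

# Model weight in half precision and its fp32 master copy.
model_param = torch.randn(4).half()
main_param = model_param.detach().clone().float()

# Gradients accumulate on (or are copied to) the fp32 copy, which the
# base optimizer updates ...
main_param -= 0.1 * torch.ones_like(main_param)   # stand-in for optimizer.step()

# ... and the updated master values are copied back into the fp16 weight.
model_param.copy_(main_param)                     # implicit fp32 -> fp16 cast
print(model_param.dtype)                          # torch.float16
```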
@@ -300,24 +378,34 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

-        # Unscale and check for inf/nan.
-        timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-        timers('optimizer-unscale-and-check-inf').stop()
+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:

-        # We are done with scaling gradients
-        # so we can update the loss scale.
-        self.grad_scaler.update(found_inf_flag)
+            # Unscale and check for inf/nan.
+            timers('optimizer-unscale-and-check-inf').start()
+            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
+            timers('optimizer-unscale-and-check-inf').stop()

-        # If we found inf/nan, skip the update.
-        if found_inf_flag:
-            return False
+            # We are done with scaling gradients
+            # so we can update the loss scale.
+            self.grad_scaler.update(found_inf_flag)
+
+            # If we found inf/nan, skip the update.
+            if found_inf_flag:
+                return False, None, None

         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
-        self.clip_grad_norm(self.clip_grad)
+        grad_norm = None
+        if self.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

+        # Count the zeros in the grads.
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None
+
         # Step the optimizer.
         self.optimizer.step()
@@ -327,14 +415,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()

         # Successful update.
-        return True
+        return True, grad_norm, num_zeros_in_grad

     def state_dict(self):
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
-        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
         return state_dict
@@ -352,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             print_rank_0('***WARNING*** found an old checkpoint, will not '
                          'load grad scaler ...')
         else:
-            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** found the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')

         # Copy data for the main params.
-        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_fp16_params_key not in state_dict:
-            fp32_from_fp16_params_key = 'fp32_from_fp16'
+        fp32_from_float16_params_key = 'fp32_from_fp16_params'
+        if fp32_from_float16_params_key not in state_dict:
+            fp32_from_float16_params_key = 'fp32_from_fp16'
         for current_group, saved_group in zip(
-                self.fp32_from_fp16_groups,
-                state_dict[fp32_from_fp16_params_key]):
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
@@ -368,10 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
 class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, clip_grad):
-        super(FP32Optimizer, self).__init__(optimizer)
-        self.clip_grad = clip_grad
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
+        super(FP32Optimizer, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad)
+
         self._scale = torch.cuda.FloatTensor([1.0])
@@ -391,15 +489,26 @@ class FP32Optimizer(MegatronOptimizer):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""

+        # Copy main_grads to grads.
+        if self.params_have_main_grad:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = param.main_grad
+
         # Clip gradients.
+        grad_norm = None
         if self.clip_grad > 0.0:
-            self.clip_grad_norm(self.clip_grad)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
+
+        # Count the zeros in the grads.
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None

         # Update parameters.
         self.optimizer.step()

         # No overflow for FP32 optimizer.
-        return True
+        return True, grad_norm, num_zeros_in_grad

     def reload_model_params(self):
...
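Both optimizer classes now return a `(update_successful, grad_norm, num_zeros_in_grad)` triple from `step()` instead of a bare boolean, so a training loop can log the gradient norm and zero count, or skip the iteration on an fp16 overflow. A self-contained illustration of how a caller might unpack it (the `fake_step` helper is purely hypothetical, not from this commit):

```python
def fake_step():
    """Stand-in for MegatronOptimizer.step() after this change."""
    return True, 1.234, 0   # (update_successful, grad_norm, num_zeros_in_grad)

update_successful, grad_norm, num_zeros_in_grad = fake_step()
if not update_successful:
    print('fp16 overflow detected; skipping this iteration')
if grad_norm is not None:
    print(f'grad norm: {grad_norm:.3f}')
if num_zeros_in_grad is not None:
    print(f'zeros in grad: {num_zeros_in_grad}')
```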