Commit c9a59554 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Merge branch 'main' into lmcafee/distrib-opt-nodupe

merge sequence parallelism's layernorm all-reduce into distributed optimizer.
parents 9288b6b7 877f6a1b
Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.
Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel ([tensor](https://arxiv.org/pdf/1909.08053.pdf), [sequence](https://arxiv.org/pdf/2205.05198), and [pipeline](https://arxiv.org/pdf/2104.04473.pdf)), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.
Below are some of the projects where we have directly used Megatron:
* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf)
......@@ -8,19 +8,26 @@ Below are some of the projects where we have directly used Megatron:
* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150)
* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf)
* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html)
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
* [Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases](https://arxiv.org/abs/2112.07868)
* [Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173)
* [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990)
* [Multi-Stage Prompting for Knowledgeable Dialogue Generation](https://arxiv.org/abs/2203.08745)
Megatron is also used in [NeMo Megatron](https://developer.nvidia.com/nvidia-nemo#nemo-megatron), a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters.
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The graph below shows that we scale nearly linear up to 1 trillion parameter models running on 3072 GPUs. Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
Additionally, the model parallel size column reports a combined tensor and pipeline parallelism degrees. For numbers larger than 8, typically tensor parallel of size 8 was used. So, for example, the 145B model reports the total model parallel size of 64, which means that this setup used TP=8 and PP=8.
![Scaling Graph](images/Achieved_petaFLOPs.png)
![Cases](images/cases_april2021.png)
All the cases from 1 billion to 1 trillion parameters achieve more than 43% half precision utilization, which is high for an end-to-end application. We observe that initially the utilization remains constant but as hidden size increases for larger models, utilization starts increasing and reaches 52% for the largest model. We also note that achieved aggregate petaFLOPs across all GPUs increases almost linearly with number of GPUs, demonstrating good weak scaling.
The following table shows both model (MFU) and hardware (HFU) FLOPs utilization for select configurations up to 1T parameters (see [our paper](https://arxiv.org/pdf/2205.05198) for a description of how these are calculated). As the model size increases, we achieve better GPU utilization and for the one trillion parameter model, we reach a MFU and HFU of 56.3% and 57.0%, respectively. Note that these numbers are also measured on benchmark runs and in this case are measured using a data parallel size of one. Data parallelism introduces some overhead due to the gradient all-reduce required between the data parallel groups. However, for large transformer models, this overhead is not large and can almost entirely eliminted by overlapping the gradient all-reduce with backpropagation.
| Model Size | Model FLOPs Utilization | Hardware FLOPs Utilization |
| :---: | :---: | :---: |
| 22B | 41.5% | 43.7% |
| 175B | 51.4% | 52.8% |
| 530B | 56.0% | 57.0% |
| 1T | 56.3% | 57.0% |
# Contents
* [Contents](#contents)
......@@ -257,7 +264,9 @@ The `examples/pretrain_{bert,gpt,t5}_distributed.sh` scripts use the PyTorch dis
We use two types of parallelism: data and model parallelism. We facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.
Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers each).
Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use sequence parallelism specify `--sequence-parallel`, which requires tensor model parallel as it split among the same GPUs.
To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers each).
<!-- The number of microbatches in a per-pipeline minibatch is controlled by the `--num-microbatches-in-minibatch` argument. With `WORLD_SIZE` GPUs, `TENSOR_MP_SIZE` tensor-model-parallel size, `PIPELINE_MP_SIZE` pipeline-model-parallel-size, `WORLD_SIZE`/(`TENSOR_MP_SIZE` * `PIPELINE_MP_SIZE`) GPUs will be used for data parallelism. The default values for `--tensor-model-parallel-size` and `--pipeline-model-parallel-size` is 1, which will not implement either form of model parallelism. -->
......@@ -291,6 +300,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_<model>.py \
--data-path $DATA_PATH \
--tensor-model-parallel-size $TENSOR_MP_SIZE \
--pipeline-model-parallel-size $PIPELINE_MP_SIZE \
--sequence-parallel \
--DDP-impl torch
</pre>
......@@ -298,11 +308,13 @@ The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper
## Activation Checkpointing and Recomputation
To reduce GPU memory usage so deploy a large model to a training system, we support activation checkpointing and recomputation. We use a Transformer layer as the unit of checkpointing because the activation size bloats in the middle of a Transformer layer so checkpointing the input of a Transformer layer is storage-efficient. We support two activation checkpointing methods: `uniform` and `block`.
To reduce GPU memory usage so deploy a large model to a training system, we support activation checkpointing and recomputation. We support two levels of recompute granularity: `selective` and `full`. Selective recomputation is the default and recommended in almost all cases. It saves the activations that take less space and are expensive to recompute and recomputes activations that take a lot of space but are relatively cheap to recompute (see [our paper](https://arxiv.org/pdf/2205.05198) for details). To enable selective activation recompute simply use `--recompute-activations`.
For cases where memory is very tight, `full` checkpointing saves just the inputs to a transformer layer, or a block of transformer layers, and recomputes everything else. To turn on full activation recompute use `--recompute-granularity full`. When using full activation recomputation, there are two methods: `uniform` and `block`, chosen using the `--recompute-method` argument.
Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in the memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage thus enables running a bigger model. For example, when using the number of layers per group of 4, the input activation of each group of 4 Transformer layers is checkpointed.
* Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in the memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage thus enables running a bigger model. For example, when using the number of layers per group of 4, the input activation of each group of 4 Transformer layers is checkpointed.
Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and do the rest of layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer transformer layers avoids unnecessary activation recomputation in the backprop thus improves training performance. For example, when we specify 5 layers to checkpoint of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed and activation recomputation for the rest 3 layers is not needed in the backprop.
* Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and do the rest of layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer transformer layers avoids unnecessary activation recomputation in the backprop thus improves training performance. For example, when we specify 5 layers to checkpoint of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed and activation recomputation for the rest 3 layers is not needed in the backprop.
## GPT-3 Example
......
......@@ -19,6 +19,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--sequence-parallel \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -103,14 +103,20 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
args.recompute_granularity = 'full'
args.recompute_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
'use --recompute-granularity and --recompute-method instead. '
'Defaulting to recompute-granularity=full and recompute-method=uniform.')
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
......@@ -284,19 +290,32 @@ def parse_args(extra_args_provider=None, defaults={},
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation checkpointing.
if args.distribute_checkpointed_activations:
# Activation recomputing.
if args.distribute_saved_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'recomputed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distributed checkpoint activations to work you '\
'need to use a activation-checkpoint method '
assert args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'application to full recompute granularity'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed checkpoint activations are supported for pytorch ' \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity'
# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
_print_args(args)
return args
......@@ -477,27 +496,40 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute checkpointed activations '
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) checkpoint the input activations of only a set number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'default) do not apply activations checkpoint to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
'rest without any recomputing at specified granularity'
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
'to recompute within each pipeline stage.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -546,6 +578,8 @@ def _add_training_args(parser):
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
......
......@@ -295,14 +295,19 @@ class IndexedDatasetBuilder(object):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
doc_offset = len(self.sizes)
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
for data_offset in index.data_offsets[1:]:
self.data_offsets.append(begin + data_offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
self.doc_idx.extend((doc_offset + index.doc_idx)[1:])
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
......@@ -556,8 +561,9 @@ class MMapIndexedDatasetBuilder(object):
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
offset = len(self._sizes)
self._sizes.extend(index.sizes)
self._doc_idx.extend((offset + index.doc_idx)[1:])
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
......
......@@ -31,6 +31,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
def initialize_megatron(extra_args_provider=None, args_defaults={},
......@@ -64,9 +66,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
args = get_args()
if args.lazy_mpu_init:
args.use_cpu_initialization=True
......@@ -230,7 +229,7 @@ def write_args_to_tensorboard():
global_step=args.iteration)
def _set_jit_fusion_options():
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
......@@ -251,3 +250,51 @@ def _set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
_warmup_jit_function()
def _warmup_jit_function():
""" Compilie JIT functions before the main training steps """
args = get_args()
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
else:
dtype = torch.float32
# Warmup fused bias+gelu
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype, device='cuda')
input = torch.rand((args.seq_length, args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size),
dtype=dtype, device='cuda')
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
bias.requires_grad, input.requires_grad = bias_grad, input_grad
for _ in range(5):
output = bias_gelu(bias, input)
del bias, input, output
# Warmup fused bias+dropout+add
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
input.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
for _ in range(5):
output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
del bias, input, residual, output
torch.cuda.empty_cache()
......@@ -78,7 +78,12 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
self.layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon,
sequence_parallel=args.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
......@@ -110,14 +115,20 @@ def post_language_model_processing(lm_output, pooled_output,
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
# [s b h] => [b s h]
return lm_logits.transpose(0,1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0,1).contiguous()
# lm_logits : [s, b, h] and lm_labels: [s, b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s, b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss, binary_logits
......
......@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
pool_mask = (input_ids == self.pad_id).unsqueeze(2)
# Taking the representation of the [CLS] token of BERT
pooled_output = lm_output[:, 0, :]
pooled_output = lm_output[0, :, :]
# Converting to float16 dtype
pooled_output = pooled_output.to(lm_output.dtype)
......
......@@ -69,7 +69,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
def __init__(self, normalized_shape, eps=1e-5,
no_persist_layer_norm=True,
sequence_parallel=False):
super(MixedFusedLayerNorm, self).__init__()
global fused_mix_prec_layer_norm_cuda
......@@ -94,6 +96,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
def reset_parameters(self):
......
......@@ -32,20 +32,26 @@ def post_language_model_processing(lm_output, labels, logit_weights,
parallel_output,
fp16_lm_cross_entropy):
# Output.
# Output. Format [s b h]
output = parallel_lm_logits(
lm_output,
logit_weights,
parallel_output)
if labels is None:
return output
# [s b h] => [b s h]
return output.transpose(0,1).contiguous()
else:
# [b s] => [s b]
labels = labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = mpu.vocab_parallel_cross_entropy(output, labels)
else:
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
return loss
......
......@@ -26,23 +26,29 @@ from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce:
if args.async_tensor_model_parallel_allreduce or\
args.sequence_parallel:
input_parallel = input_
async_grad_allreduce = mpu.get_tensor_model_parallel_world_size() > 1
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce)
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce, args.sequence_parallel)
# Gather if needed.
if parallel_output:
return logits_parallel
......@@ -98,12 +104,23 @@ class Pooler(MegatronModule):
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
args = get_args()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.sequence_parallel = args.sequence_parallel
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states,
to_model_parallel=False)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
......@@ -164,6 +181,8 @@ class Embedding(MegatronModule):
else:
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
......@@ -205,8 +224,20 @@ class Embedding(MegatronModule):
else:
assert self.tokentype_embeddings is None
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
embeddings = self.embedding_dropout(embeddings)
if self.sequence_parallel:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
with mpu.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
......
......@@ -152,19 +152,24 @@ class T5Model(MegatronModule):
if self.post_process and self.add_decoder:
decoder_output, encoder_output = lm_output
# Output.
# Output. [s, b, h]
lm_logits = self.lm_head(decoder_output,
self.word_embeddings_weight())
if lm_labels is None:
return lm_logits
# [s b h] => [b s h]
return lm_logits.transpose(0,1).contiguous()
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0,1).contiguous()
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss
elif self.add_decoder and not self.add_encoder:
decoder_output, encoder_output = lm_output
......
......@@ -15,10 +15,11 @@
"""Transformer."""
import math
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers, get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
......@@ -27,6 +28,7 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
......@@ -129,21 +130,21 @@ class SwitchMLP(MegatronModule):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [b, s, h]
b = hidden_states.size(0)
s = hidden_states.size(1)
# hidden_states: [s, b, h]
s = hidden_states.size(0)
b = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]
# TODO (rprenger) TODO this could be made easier to read
# Converting [b, s, h] to [b*s, h].
# Converting [s, b, h] to [s*b, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
max_ind = max_ind.view(-1) # [b*s]
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
max_ind = max_ind.view(-1) # [s*b]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
......@@ -159,15 +160,156 @@ class SwitchMLP(MegatronModule):
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(b, s, h)
output_bias_total = output_bias_total.view(b, s, h)
output_total = output_total.view(s, b, h)
output_bias_total = output_bias_total.view(s, b, h)
return output_total, output_bias_total
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
......@@ -177,13 +319,6 @@ class ParallelAttention(MegatronModule):
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
......@@ -193,8 +328,6 @@ class ParallelAttention(MegatronModule):
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
......@@ -221,24 +354,9 @@ class ParallelAttention(MegatronModule):
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Output.
self.dense = mpu.RowParallelLinear(
......@@ -248,6 +366,23 @@ class ParallelAttention(MegatronModule):
init_method=output_layer_init_method,
skip_bias_add=True)
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask):
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
key_layer = inputs[1]
value_layer = inputs[2]
attention_mask = inputs[3]
output_ = self.core_attention(query_layer, key_layer,
value_layer, attention_mask)
return output_
hidden_states = mpu.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
return hidden_states
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
......@@ -257,13 +392,11 @@ class ParallelAttention(MegatronModule):
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
......@@ -281,7 +414,6 @@ class ParallelAttention(MegatronModule):
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
......@@ -322,7 +454,6 @@ class ParallelAttention(MegatronModule):
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
......@@ -344,90 +475,16 @@ class ParallelAttention(MegatronModule):
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# ==================================
# core attention computation
# ==================================
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
# =================
# Output. [sq, b, h]
......@@ -470,7 +527,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
......@@ -494,7 +551,8 @@ class ParallelTransformerLayer(MegatronModule):
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel)
# Self attention.
self.self_attention = ParallelAttention(
......@@ -511,7 +569,8 @@ class ParallelTransformerLayer(MegatronModule):
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
......@@ -523,7 +582,8 @@ class ParallelTransformerLayer(MegatronModule):
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel)
# MLP
if args.num_experts is not None:
......@@ -531,10 +591,17 @@ class ParallelTransformerLayer(MegatronModule):
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [b, s, h]
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
......@@ -564,8 +631,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
......@@ -591,8 +657,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
......@@ -612,8 +677,7 @@ class ParallelTransformerLayer(MegatronModule):
residual = layernorm_input
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
......@@ -660,22 +724,30 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
args = get_args()
self.layer_type = layer_type
self.model_type = args.model_type
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = mpu.get_num_layers(
......@@ -739,12 +811,13 @@ class ParallelTransformer(MegatronModule):
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel)
def _get_layer(self, layer_number):
return self.layers[layer_number]
......@@ -764,32 +837,33 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
if self.activations_checkpoint_method == 'uniform':
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage reducing checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
self.distribute_checkpointed_activations,
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
l += self.recompute_num_layers
elif self.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method fully use the device memory removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
if l < self.recompute_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
self.distribute_checkpointed_activations,
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
raise ValueError("Invalid activation recompute method.")
return hidden_states
......@@ -806,21 +880,14 @@ class ParallelTransformer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Checks.
if inference_params:
assert self.activations_checkpoint_method is None, \
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
......@@ -841,37 +908,34 @@ class ParallelTransformer(MegatronModule):
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
requires_grad=True,
keep_graph=True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
rng_context = nullcontext()
with rng_context:
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return output
return hidden_states
......@@ -21,7 +21,6 @@ import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
......@@ -148,6 +147,7 @@ class VitBackbone(MegatronModule):
post_process=True,
class_token=True,
single_token_output=False,
post_layer_norm=True,
drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args()
......@@ -165,6 +165,7 @@ class VitBackbone(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.class_token = class_token
self.post_layer_norm = post_layer_norm
self.hidden_size = args.hidden_size
self.patch_dim = args.patch_dim
self.img_h = args.img_h
......@@ -218,6 +219,7 @@ class VitBackbone(MegatronModule):
self.scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process,
post_layer_norm=self.post_layer_norm,
drop_path_rate=self.drop_path_rate
)
......
......@@ -49,18 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncAllreduce
from .layers import LinearWithGradAccumulationAndAsyncCommunication
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
......
......@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
......@@ -199,35 +200,80 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous all-reduce and gradient accumulation
Linear layer execution with asynchronous communication and gradient accumulation
fusion in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce):
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
ctx.sequence_parallel = sequence_parallel
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group())
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
import fused_dense_cuda
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of intput gradient computation shortly (3us) to have
# gather scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel:
handle.wait()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
......@@ -235,15 +281,38 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
import fused_dense_cuda
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None, None
class ColumnParallelLinear(torch.nn.Module):
......@@ -323,23 +392,28 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = (
args.async_tensor_model_parallel_allreduce and
world_size > 1)
self.sequence_parallel = (
args.sequence_parallel and
world_size > 1)
assert not self.async_tensor_model_parallel_allreduce or \
not self.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce:
if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel:
input_parallel = input_
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce)
self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
......@@ -420,26 +494,34 @@ class RowParallelLinear(torch.nn.Module):
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=args.params_dtype))
setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.sequence_parallel = args.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, None,
self.gradient_accumulation_fusion, None)
self.gradient_accumulation_fusion, None, None)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
......@@ -32,13 +32,13 @@ def _reduce(input_):
return input_
def _split(input_):
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
if world_size == 1:
return input_
# Split along last dimension.
......@@ -51,12 +51,34 @@ def _split(input_):
return output
def _gather(input_):
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
if world_size == 1:
return input_
# Size and dimension.
......@@ -73,6 +95,44 @@ def _gather(input_):
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
......@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
......@@ -126,15 +186,69 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, to_model_parallel=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, to_model_parallel=True):
ctx.to_model_parallel = to_model_parallel
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
to_model_parallel = ctx.to_model_parallel
if to_model_parallel:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
......@@ -155,3 +269,16 @@ def scatter_to_tensor_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
......@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_checkpointed_activations, *args):
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_checkpointed_activations \
= distribute_checkpointed_activations
ctx.distribute_saved_activations \
= distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_checkpointed_activations:
if distribute_saved_activations:
ctx.input_0_shape = args[0].data.shape
safely_set_viewless_tensor_data(
args[0],
......@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
if ctx.distribute_checkpointed_activations:
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
inputs[0],
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
......@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
return (None, None) + grads
def checkpoint(function, distribute_checkpointed_activations, *args):
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
distribute_saved_activations, *args)
......@@ -17,7 +17,6 @@ from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import LayerNorm
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
......
......@@ -514,7 +514,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
......@@ -532,10 +531,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment