"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "5d194c5db0da8eeb591bfcf5e946b79e6efd6853"
Commit 67aa8619 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into checkpoint_util

parents 03d09af0 f5345dfa
__pycache__ __pycache__
# Distribution / packaging
build/
dist/
*.egg-info/
\ No newline at end of file
include megatron/data/Makefile
include megatron/data/helpers.cpp
...@@ -11,7 +11,7 @@ Below are some of the projects where we have directly used Megatron: ...@@ -11,7 +11,7 @@ Below are some of the projects where we have directly used Megatron:
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf) * [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) * [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
![Cases](images/cases_april2021.png) ![Cases](images/cases_april2021.png)
...@@ -48,13 +48,6 @@ We have tested Megatron with [NGC's PyTorch container](https://ngc.nvidia.com/ca ...@@ -48,13 +48,6 @@ We have tested Megatron with [NGC's PyTorch container](https://ngc.nvidia.com/ca
To use this repository, please install the latest supported versions of PyTorch with GPU support (python 3.8, pytorch 1.8, cuda 11.1, and nccl 2.8.3 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We strongly recommend using one of [NGC's recent PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) (the latest compatible version at time of publication can be pulled with `docker pull nvcr.io/nvidia/pytorch:20.12-py3`). Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. To use this repository, please install the latest supported versions of PyTorch with GPU support (python 3.8, pytorch 1.8, cuda 11.1, and nccl 2.8.3 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We strongly recommend using one of [NGC's recent PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) (the latest compatible version at time of publication can be pulled with `docker pull nvcr.io/nvidia/pytorch:20.12-py3`). Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks.
<!--
To use megatron you can either clone the repo or install it via pip (make sure python3-dev is installed):
<pre>
pip install megatron-lm
</pre>
-->
## Downloading Checkpoints ## Downloading Checkpoints
We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1). We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1).
...@@ -419,33 +412,23 @@ python tools/checkpoint_util.py \ ...@@ -419,33 +412,23 @@ python tools/checkpoint_util.py \
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation ## GPT Text Generation
`bash examples/generate_text.sh`
We generate text samples using largely the GPT pretraining script. Few changes need to make, such as we need to provide the path to the pretrained checkpoint, the length of the output samples, whether to generate texts unconditionally (`--num-samples` to denote how many samples to generate) or conditional (need to pass `--sample-input-file <filename>` where each line of the file will be used as the conditional texts). There are few optional parameters to play, e.g. `top-k`, `top-p`, or `greedy` (set top-k and top-p to 0) sampling.. We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.
<pre> <pre>
CHECKPOINT_PATH=checkpoints/gpt2_345m tools/text_generation_cli.py localhost
VOCAB_FILE=gpt2-vocab.json </pre>
MERGE_FILE=gpt2-merges.txt
GPT_ARGS=&#60;same as those in <a href="#gpt-pretraining">GPT pretraining</a> above&#62;
MAX_OUTPUT_SEQUENCE_LENGTH=1024 You can also use CURL or any other tools to query the server directly:
TEMPERATURE=1.0
TOP_P=0.9
NUMBER_OF_SAMPLES=2
OUTPUT_FILE=samples.json
python tools/generate_samples_gpt.py \ <pre>
$GPT_ARGS \ curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
--load $CHECKPOINT_PATH \
--out-seq-length $MAX_OUTPUT_SEQUENCE_LENGTH \
--temperature $TEMPERATURE \
--genfile $OUTPUT_FILE \
--num-samples $NUMBER_OF_SAMPLES \
--top_p $TOP_P \
--recompute
</pre> </pre>
See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options.
## GPT Evaluation ## GPT Evaluation
We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy. We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy.
......
#!/bin/bash
CHECKPOINT_PATH=checkpoints/gpt2_345m
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
--num-attention-heads 16 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--batch-size 2 \
--seq-length 1024 \
--out-seq-length 1024 \
--temperature 1.0 \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--genfile unconditional_samples.json \
--num-samples 2 \
--top_p 0.9 \
--recompute
...@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)> ...@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful pip install flask-restful
python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py / python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
--tensor-model-parallel-size 1 / --tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 / --pipeline-model-parallel-size 1 \
--num-layers 24 / --num-layers 24 \
--hidden-size 1024 / --hidden-size 1024 \
--load ${CHECKPOINT} / --load ${CHECKPOINT} \
--num-attention-heads 16 / --num-attention-heads 16 \
--max-position-embeddings 1024 / --max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer / --tokenizer-type GPT2BPETokenizer \
--fp16 / --fp16 \
--micro-batch-size 1 / --micro-batch-size 1 \
--seq-length 1024 / --seq-length 1024 \
--out-seq-length 1024 / --out-seq-length 1024 \
--temperature 1.0 / --temperature 1.0 \
--vocab-file $VOCAB_FILE / --vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE / --merge-file $MERGE_FILE \
--top_p 0.9 / --top_p 0.9 \
--seed 42 --seed 42
...@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)> ...@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful pip install flask-restful
python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py / python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
--tensor-model-parallel-size 8 / --tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 / --pipeline-model-parallel-size 1 \
--num-layers 24 / --num-layers 24 \
--hidden-size 1024 / --hidden-size 1024 \
--load ${CHECKPOINT} / --load ${CHECKPOINT} \
--num-attention-heads 16 / --num-attention-heads 16 \
--max-position-embeddings 1024 / --max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer / --tokenizer-type GPT2BPETokenizer \
--fp16 / --fp16 \
--micro-batch-size 1 / --micro-batch-size 1 \
--seq-length 1024 / --seq-length 1024 \
--out-seq-length 1024 / --out-seq-length 1024 \
--temperature 1.0 / --temperature 1.0 \
--vocab-file $VOCAB_FILE / --vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE / --merge-file $MERGE_FILE \
--top_p 0.9 / --top_p 0.9 \
--seed 42 --seed 42
...@@ -14,17 +14,6 @@ ...@@ -14,17 +14,6 @@
# limitations under the License. # limitations under the License.
import torch import torch
from .package_info import (
__description__,
__contact_names__,
__url__,
__download_url__,
__keywords__,
__license__,
__package_name__,
__version__,
)
from .global_vars import get_args from .global_vars import get_args
from .global_vars import get_current_global_batch_size from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches from .global_vars import get_num_microbatches
......
...@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_biencoder_args(parser) parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser) parser = _add_vit_args(parser)
parser = _add_logging_args(parser) parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
# Custom arguments. # Custom arguments.
if extra_args_provider is not None: if extra_args_provider is not None:
...@@ -85,6 +86,12 @@ def validate_args(args, defaults={}): ...@@ -85,6 +86,12 @@ def validate_args(args, defaults={}):
args.world_size, args.data_parallel_size, args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True) args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
# Deprecated arguments # Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \ assert args.batch_size is None, '--batch-size argument is no longer ' \
...@@ -278,6 +285,18 @@ def _check_arg_is_not_none(args, arg): ...@@ -278,6 +285,18 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg) assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
return parser
def _add_network_size_args(parser): def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size') group = parser.add_argument_group(title='network size')
...@@ -467,6 +486,11 @@ def _add_training_args(parser): ...@@ -467,6 +486,11 @@ def _add_training_args(parser):
group.add_argument('--dataloader-type', type=str, default=None, group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'], choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader') help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
return parser return parser
...@@ -606,6 +630,9 @@ def _add_distributed_args(parser): ...@@ -606,6 +630,9 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.') help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1, group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.') help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None, group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use ' help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.') '--tensor-model-parallel-size instead.')
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "ATen/ATen.h" #include "ATen/ATen.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh> #include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -329,6 +329,7 @@ void cuApplyLayerNorm( ...@@ -329,6 +329,7 @@ void cuApplyLayerNorm(
mean[i1] = mu; mean[i1] = mu;
invvar[i1] = c_invvar; invvar[i1] = c_invvar;
} }
__syncthreads();
} }
} }
...@@ -644,6 +645,8 @@ void cuComputeGradInput( ...@@ -644,6 +645,8 @@ void cuComputeGradInput(
k_grad_input[l] = static_cast<T>(f_grad_input); k_grad_input[l] = static_cast<T>(f_grad_input);
} }
} }
// prevent race where buf is written again before reads are done
__syncthreads();
} }
} }
......
...@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, ...@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
print('> setting random seeds to {} ...'.format(args.seed)) print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed) _set_random_seed(args.seed)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
args = get_args() args = get_args()
if args.lazy_mpu_init: if args.lazy_mpu_init:
args.use_cpu_initialization=True args.use_cpu_initialization=True
...@@ -173,11 +176,11 @@ def _initialize_distributed(): ...@@ -173,11 +176,11 @@ def _initialize_distributed():
else: else:
args.local_rank = device args.local_rank = device
torch.cuda.set_device(device) torch.cuda.set_device(device)
# Call the init process # Call the init process
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.distributed_backend, backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank, world_size=args.world_size, rank=args.rank,
timeout=timedelta(days=7)) timeout=timedelta(days=7))
# Set the tensor model-parallel, pipeline model-parallel, and # Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators. # data-parallel communicators.
...@@ -187,7 +190,8 @@ def _initialize_distributed(): ...@@ -187,7 +190,8 @@ def _initialize_distributed():
else: else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size, mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size, args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size) args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
def _init_autoresume(): def _init_autoresume():
...@@ -222,3 +226,25 @@ def write_args_to_tensorboard(): ...@@ -222,3 +226,25 @@ def write_args_to_tensorboard():
writer.add_text(arg, str(getattr(args, arg)), writer.add_text(arg, str(getattr(args, arg)),
global_step=args.iteration) global_step=args.iteration)
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
...@@ -21,3 +21,4 @@ from .gpt_model import GPTModel ...@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
from .t5_model import T5Model from .t5_model import T5Model
from .language_model import get_language_model from .language_model import get_language_model
from .module import Float16Module from .module import Float16Module
from .enums import ModelType
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
import enum import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum): class LayerType(enum.Enum):
encoder = 1 encoder = 1
decoder = 2 decoder = 2
......
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
import torch import torch
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
###### BIAS GELU FUSION/ NO AUTOGRAD ################ ###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2*pi)-> 0.3989423
......
...@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal ...@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
def post_language_model_processing(lm_output, labels, logit_weights, def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output, parallel_output,
forward_method_parallel_output,
fp16_lm_cross_entropy): fp16_lm_cross_entropy):
if get_key_value:
lm_output, presents = lm_output
# Output. # Output.
if forward_method_parallel_output is not None:
parallel_output = forward_method_parallel_output
output = parallel_lm_logits( output = parallel_lm_logits(
lm_output, lm_output,
logit_weights, logit_weights,
parallel_output) parallel_output)
if get_key_value:
output = [output, presents]
if labels is None: if labels is None:
return output return output
else: else:
...@@ -90,23 +82,19 @@ class GPTModel(MegatronModule): ...@@ -90,23 +82,19 @@ class GPTModel(MegatronModule):
self.language_model.set_input_tensor(input_tensor) self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None, def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, inference_params=None):
forward_method_parallel_output=None):
lm_output = self.language_model( lm_output = self.language_model(
input_ids, input_ids,
position_ids, position_ids,
attention_mask, attention_mask,
layer_past=layer_past, inference_params=inference_params)
get_key_value=get_key_value)
if self.post_process: if self.post_process:
return post_language_model_processing( return post_language_model_processing(
lm_output, labels, lm_output, labels,
self.word_embeddings_weight(), self.word_embeddings_weight(),
get_key_value,
self.parallel_output, self.parallel_output,
forward_method_parallel_output,
self.fp16_lm_cross_entropy) self.fp16_lm_cross_entropy)
else: else:
return lm_output return lm_output
......
...@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None, encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False, scaled_init_method=None, add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal, decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True): pre_process=True, post_process=True):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
...@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler, ...@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
scaled_init_method, scaled_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder, add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type, decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler, add_pooler=add_pooler,
...@@ -161,6 +163,16 @@ class Embedding(MegatronModule): ...@@ -161,6 +163,16 @@ class Embedding(MegatronModule):
# Embeddings dropout # Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes): def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add """Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it. token-type embeddings in case the pretrained model does not have it.
...@@ -275,6 +287,7 @@ class TransformerLanguageModel(MegatronModule): ...@@ -275,6 +287,7 @@ class TransformerLanguageModel(MegatronModule):
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
add_encoder=True,
add_decoder=False, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal, decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False, add_pooler=False,
...@@ -288,10 +301,12 @@ class TransformerLanguageModel(MegatronModule): ...@@ -288,10 +301,12 @@ class TransformerLanguageModel(MegatronModule):
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings. # Embeddings.
if self.pre_process: if self.pre_process:
...@@ -304,25 +319,37 @@ class TransformerLanguageModel(MegatronModule): ...@@ -304,25 +319,37 @@ class TransformerLanguageModel(MegatronModule):
self._embedding_key = 'embedding' self._embedding_key = 'embedding'
# Transformer. # Transformer.
self.encoder = ParallelTransformer( # Encoder (usually set to True, False if part of an encoder-decoder
self.init_method, # architecture and in encoder-only stage).
output_layer_init_method, if self.add_encoder:
self_attn_mask_type=self.encoder_attn_mask_type, self.encoder = ParallelTransformer(
pre_process=self.pre_process, self.init_method,
post_process=self.post_process output_layer_init_method,
) self_attn_mask_type=self.encoder_attn_mask_type,
self._encoder_key = 'encoder' pre_process=self.pre_process,
post_process=self.post_process
# Decoder )
self._encoder_key = 'encoder'
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder: if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert args.pipeline_model_parallel_size == 1, \ assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder' 'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer( self.decoder = ParallelTransformer(
self.init_method, self.init_method,
output_layer_init_method, output_layer_init_method,
layer_type=LayerType.decoder, layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type) self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process)
self._decoder_key = 'decoder' self._decoder_key = 'decoder'
else:
self.decoder = None
if self.post_process: if self.post_process:
# Pooler. # Pooler.
...@@ -332,28 +359,55 @@ class TransformerLanguageModel(MegatronModule): ...@@ -332,28 +359,55 @@ class TransformerLanguageModel(MegatronModule):
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()""" """ See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None, enc_dec_attn_mask=None, tokentype_ids=None,
get_key_value=False, pooling_sequence_index=0, inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False): enc_hidden_states=None, output_enc_hidden=False):
# Embeddings. # Encoder embedding.
if self.pre_process: if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids, encoder_input = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
encoder_input = embedding_output
else: else:
encoder_input = None encoder_input = None
# encoder. # Run encoder.
if enc_hidden_states is None: if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input, if self.encoder is not None:
enc_attn_mask, encoder_output = self.encoder(
layer_past=layer_past, encoder_input,
get_key_value=get_key_value) enc_attn_mask,
inference_params=inference_params)
else:
encoder_output = self.encoder_hidden_state
else: else:
encoder_output = enc_hidden_states.to(encoder_input.dtype) encoder_output = enc_hidden_states.to(encoder_input.dtype)
...@@ -371,16 +425,20 @@ class TransformerLanguageModel(MegatronModule): ...@@ -371,16 +425,20 @@ class TransformerLanguageModel(MegatronModule):
else: else:
return encoder_output return encoder_output
# Decoder Embedding # Decoder embedding.
dec_embedding_output = self.embedding(dec_input_ids, if self.pre_process:
dec_position_ids) decoder_input = self.embedding(dec_input_ids,
# decoder dec_position_ids)
decoder_output = self.decoder(dec_embedding_output, else:
dec_attn_mask, decoder_input = None
layer_past=layer_past,
get_key_value=get_key_value, # Run decoder.
encoder_output=encoder_output, decoder_output = self.decoder(
enc_dec_attn_mask=enc_dec_attn_mask) decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
if self.add_pooler and self.post_process: if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output return decoder_output, encoder_output, pooled_output
...@@ -396,9 +454,10 @@ class TransformerLanguageModel(MegatronModule): ...@@ -396,9 +454,10 @@ class TransformerLanguageModel(MegatronModule):
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._encoder_key] \ if self.add_encoder:
= self.encoder.state_dict_for_save_checkpoint( state_dict_[self._encoder_key] \
destination, prefix, keep_vars) = self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process: if self.post_process:
if self.add_pooler: if self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
...@@ -427,38 +486,39 @@ class TransformerLanguageModel(MegatronModule): ...@@ -427,38 +486,39 @@ class TransformerLanguageModel(MegatronModule):
self.embedding.load_state_dict(state_dict_, strict=strict) self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder. # Encoder.
if self._encoder_key in state_dict: if self.add_encoder:
state_dict_ = state_dict[self._encoder_key] if self._encoder_key in state_dict:
# for backward compatibility. state_dict_ = state_dict[self._encoder_key]
elif 'transformer' in state_dict: # For backward compatibility.
state_dict_ = state_dict['transformer'] elif 'transformer' in state_dict:
else: state_dict_ = state_dict['transformer']
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else: else:
state_dict_self_attention[key] = state_dict_[key] # For backward compatibility.
state_dict_ = state_dict_self_attention state_dict_ = {}
for key in state_dict.keys():
self.encoder.load_state_dict(state_dict_, strict=strict) if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.post_process: if self.post_process:
# pooler
if self.add_pooler: if self.add_pooler:
assert 'pooler' in state_dict, \ assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict) strict=strict)
# decoder # Decoder.
if self.add_decoder: if self.add_decoder:
assert 'decoder' in state_dict, \ assert 'decoder' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
......
...@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module): ...@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self): def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True): if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
mpu.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True): else:
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last ' raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false') 'stage, but share_word_embeddings is false')
return self.word_embeddings.weight return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal): def initialize_word_embeddings(self, init_method_normal):
...@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module): ...@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage # This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline # when we are using pipeline parallelism. Nothing to do if we aren't
# parallelism there is nothing to do. # using pipeline parallelism.
if args.pipeline_model_parallel_size == 1: if args.pipeline_model_parallel_size == 1:
return return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
# workers, so we do the following: # workers, so we do the following:
...@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module): ...@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data, torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group()) group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
else: else:
print("WARNING! Distributed processes aren't initialized, so " print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. " "word embeddings in the last layer are not initialized. "
...@@ -166,6 +187,10 @@ class Float16Module(MegatronModule): ...@@ -166,6 +187,10 @@ class Float16Module(MegatronModule):
self.float16_convertor = float16_convertor self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor) inputs = fp32_to_float16(inputs, self.float16_convertor)
......
...@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule): ...@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):
class T5Model(MegatronModule): class T5Model(MegatronModule):
"""T5 Language model.""" """T5 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
add_encoder=True,
add_decoder=True):
super(T5Model, self).__init__() super(T5Model, self).__init__()
args = get_args() args = get_args()
...@@ -95,19 +101,29 @@ class T5Model(MegatronModule): ...@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std, scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers) args.num_layers)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
add_decoder=True, add_encoder=add_encoder,
add_decoder=add_decoder,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.lm_head = T5LMHead( self.initialize_word_embeddings(init_method_normal)
self.language_model.embedding.word_embeddings.weight.size(0),
parallel_output) if self.post_process and self.add_decoder:
self._lm_head_key = 'lm_head' self.lm_head = T5LMHead(
self.word_embeddings_weight().size(0),
parallel_output)
self._lm_head_key = 'lm_head'
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()""" """See megatron.model.transformer.set_input_tensor()"""
...@@ -134,22 +150,28 @@ class T5Model(MegatronModule): ...@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
tokentype_ids=tokentype_ids, tokentype_ids=tokentype_ids,
enc_hidden_states=enc_hidden_states) enc_hidden_states=enc_hidden_states)
decoder_output, encoder_output = lm_output if self.post_process and self.add_decoder:
decoder_output, encoder_output = lm_output
# Output. # Output.
lm_logits = self.lm_head(decoder_output, lm_logits = self.lm_head(decoder_output,
self.language_model.embedding.word_embeddings.weight) self.word_embeddings_weight())
if lm_labels is None: if lm_labels is None:
return lm_logits, encoder_output return lm_logits
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else: else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), if self.fp16_lm_cross_entropy:
lm_labels) assert lm_logits.dtype == torch.half
return lm_loss, encoder_output lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss
elif self.add_decoder and not self.add_encoder:
decoder_output, encoder_output = lm_output
return decoder_output
else:
encoder_output = lm_output
return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -160,9 +182,14 @@ class T5Model(MegatronModule): ...@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \ if self.post_process and self.add_decoder:
= self.lm_head.state_dict_for_save_checkpoint( state_dict_[self._lm_head_key] \
destination, prefix, keep_vars) = self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
...@@ -170,5 +197,10 @@ class T5Model(MegatronModule): ...@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key], if self.post_process and self.add_decoder:
strict=strict) self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
# Load word embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
...@@ -21,17 +21,12 @@ import torch.nn.functional as F ...@@ -21,17 +21,12 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file: """ We use the following notation throughout this file:
h: hidden size h: hidden size
...@@ -123,6 +118,7 @@ class ParallelAttention(MegatronModule): ...@@ -123,6 +118,7 @@ class ParallelAttention(MegatronModule):
self.layer_number = max(1, layer_number) self.layer_number = max(1, layer_number)
self.attention_type = attention_type self.attention_type = attention_type
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads projection_size = args.kv_channels * args.num_attention_heads
...@@ -183,10 +179,40 @@ class ParallelAttention(MegatronModule): ...@@ -183,10 +179,40 @@ class ParallelAttention(MegatronModule):
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False, encoder_output=None): def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# ===================== # =====================
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
...@@ -227,18 +253,28 @@ class ParallelAttention(MegatronModule): ...@@ -227,18 +253,28 @@ class ParallelAttention(MegatronModule):
self.hidden_size_per_attention_head) self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape) query_layer = query_layer.view(*new_tensor_shape)
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
# ================================== # ==================================
if layer_past is not None: if inference_params:
past_key, past_value = layer_past batch_start = inference_params.batch_size_offset
key_layer = torch.cat((past_key.type_as(key_layer), batch_end = batch_start + key_layer.size(1)
key_layer), dim=0) assert batch_end <= inference_key_memory.size(1)
value_layer = torch.cat((past_value.type_as(value_layer), sequence_start = inference_params.sequence_len_offset
value_layer), dim=0) sequence_end = sequence_start + key_layer.size(0)
if get_key_value: assert sequence_end <= inference_key_memory.size(0)
present = (key_layer, value_layer) # Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# =================================== # ===================================
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
...@@ -275,22 +311,6 @@ class ParallelAttention(MegatronModule): ...@@ -275,22 +311,6 @@ class ParallelAttention(MegatronModule):
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# =========================== # ===========================
# Attention probs and dropout # Attention probs and dropout
...@@ -346,9 +366,6 @@ class ParallelAttention(MegatronModule): ...@@ -346,9 +366,6 @@ class ParallelAttention(MegatronModule):
output, bias = self.dense(context_layer) output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias return output, bias
...@@ -366,14 +383,18 @@ def get_bias_dropout_add(training): ...@@ -366,14 +383,18 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob): def bias_dropout_add_fused_train(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob): def bias_dropout_add_fused_inference(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
...@@ -436,20 +457,17 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -436,20 +457,17 @@ class ParallelTransformerLayer(MegatronModule):
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None, encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False): inference_params=None):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.self_attention(layernorm_output, self.self_attention(
attention_mask, layernorm_output,
layer_past=layer_past, attention_mask,
get_key_value=get_key_value) inference_params=inference_params)
if get_key_value:
attention_output, presents = attention_output
# Residual connection. # Residual connection.
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
...@@ -519,9 +537,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -519,9 +537,6 @@ class ParallelTransformerLayer(MegatronModule):
residual, residual,
self.hidden_dropout) self.hidden_dropout)
if get_key_value:
output = [output, presents]
return output return output
...@@ -547,9 +562,8 @@ class ParallelTransformer(MegatronModule): ...@@ -547,9 +562,8 @@ class ParallelTransformer(MegatronModule):
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers. # Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \ self.num_layers = mpu.get_num_layers(
'num_layers must be divisible by pipeline_model_parallel_size' args, args.model_type == ModelType.encoder_and_decoder)
self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
...@@ -664,18 +678,14 @@ class ParallelTransformer(MegatronModule): ...@@ -664,18 +678,14 @@ class ParallelTransformer(MegatronModule):
forward_step_func""" forward_step_func"""
self.input_tensor = input_tensor self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask,
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None): encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks. # Checks.
if layer_past is not None: if inference_params:
assert get_key_value, \
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert self.activations_checkpoint_method is None, \ assert self.activations_checkpoint_method is None, \
'get_key_value does not work with ' \ 'inference does not work with activation checkpointing'
'activation checkpointing'
if self.pre_process: if self.pre_process:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h]. # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
...@@ -698,22 +708,15 @@ class ParallelTransformer(MegatronModule): ...@@ -698,22 +708,15 @@ class ParallelTransformer(MegatronModule):
encoder_output, encoder_output,
enc_dec_attn_mask) enc_dec_attn_mask)
else: else:
if get_key_value:
presents = []
for index in range(self.num_layers): for index in range(self.num_layers):
layer = self._get_layer(index) layer = self._get_layer(index)
past = None hidden_states = layer(
if layer_past is not None: hidden_states,
past = layer_past[index] attention_mask,
hidden_states = layer(hidden_states, encoder_output=encoder_output,
attention_mask, enc_dec_attn_mask=enc_dec_attn_mask,
encoder_output=encoder_output, inference_params=inference_params)
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm. # Final layer norm.
if self.post_process: if self.post_process:
...@@ -722,7 +725,5 @@ class ParallelTransformer(MegatronModule): ...@@ -722,7 +725,5 @@ class ParallelTransformer(MegatronModule):
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
else: else:
output = hidden_states output = hidden_states
if get_key_value:
output = [output, presents]
return output return output
...@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group ...@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank from .initialize import get_pipeline_model_parallel_last_rank
......
...@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None ...@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly. # These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None ...@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None _MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source # A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage # rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
def is_unitialized(): def is_unitialized():
...@@ -52,13 +56,19 @@ def is_unitialized(): ...@@ -52,13 +56,19 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1, def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1, pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None): virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=None):
""" """
Initialize model data parallel groups. Initialize model data parallel groups.
Arguments: Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor. tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the data-parallel groups. # Build the data-parallel groups.
...@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized' 'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \ assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized' 'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
...@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# first and last stages). # first and last stages).
if len(ranks) > 1: if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]] embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
else: else:
embedding_ranks = ranks embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks) group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks: if rank in embedding_ranks:
_EMBEDDING_GROUP = group _EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def model_parallel_is_initialized(): def model_parallel_is_initialized():
...@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank(): ...@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers must be divisible by number of ranks given to encoder'
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers must be divisible by number of ranks given to decoder'
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
else:
num_layers = args.num_layers
return num_layers
def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise.""" """Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual: if not ignore_virtual:
...@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False): ...@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
get_pipeline_model_parallel_world_size() - 1) get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank(): def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank.""" """Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment