Commit ed6d28b1 authored by mshoeybi

merged main

parents 8c119d80 1ec6b0e9
__pycache__
# Distribution / packaging
build/
dist/
*.egg-info/
\ No newline at end of file
include megatron/data/Makefile
include megatron/data/helpers.cpp
@@ -48,13 +48,6 @@ We have tested Megatron with [NGC's PyTorch container](https://ngc.nvidia.com/ca
To use this repository, please install the latest supported versions of PyTorch with GPU support (python 3.8, pytorch 1.8, cuda 11.1, and nccl 2.8.3 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We strongly recommend using one of [NGC's recent PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) (the latest compatible version at time of publication can be pulled with `docker pull nvcr.io/nvidia/pytorch:20.12-py3`). Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks.
<!--
To use megatron you can either clone the repo or install it via pip (make sure python3-dev is installed):
<pre>
pip install megatron-lm
</pre>
-->
## Downloading Checkpoints
We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use in evaluation and for finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [set up](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1).
@@ -433,33 +426,23 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.

## GPT Text Generation
We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.

Once the server is running, you can use `tools/text_generation_cli.py` to query it; it takes one argument, the host the server is running on.

<pre>
tools/text_generation_cli.py localhost
</pre>

You can also use CURL or any other tool to query the server directly:

<pre>
curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
</pre>

See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options.
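For programmatic access, the same HTTP endpoint can be queried from Python. The snippet below is a minimal sketch rather than part of the repository: it assumes the server started with `tools/run_text_generation_server.py` is listening on `localhost:5000`, reuses the request shape from the `curl` example above, and relies on the third-party `requests` package. The exact fields of the JSON reply are defined by `megatron/text_generation_server.py`.

<pre>
# Minimal sketch (assumption): query the Megatron text generation REST server
# started by tools/run_text_generation_server.py and listening on localhost:5000.
# Requires the third-party `requests` package (pip install requests).
import json
import requests

def generate(prompts, tokens_to_generate=8, url='http://localhost:5000/api'):
    """PUT a JSON payload to the server and return its parsed JSON reply."""
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    payload = {'prompts': prompts, 'tokens_to_generate': tokens_to_generate}
    response = requests.put(url, data=json.dumps(payload), headers=headers)
    response.raise_for_status()
    return response.json()

if __name__ == '__main__':
    # Mirrors the curl example: one prompt, one token to generate.
    print(generate(['Hello world'], tokens_to_generate=1))
</pre>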
## GPT Evaluation
We include example scripts for GPT evaluation on WikiText perplexity and LAMBADA Cloze accuracy.
......
#!/bin/bash
CHECKPOINT_PATH=checkpoints/gpt2_345m
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
--num-attention-heads 16 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--batch-size 2 \
--seq-length 1024 \
--out-seq-length 1024 \
--temperature 1.0 \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--genfile unconditional_samples.json \
--num-samples 2 \
--top_p 0.9 \
--recompute
@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful

python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>
pip install flask-restful

python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
       --tensor-model-parallel-size 8 \
       --pipeline-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load ${CHECKPOINT} \
       --num-attention-heads 16 \
       --max-position-embeddings 1024 \
       --tokenizer-type GPT2BPETokenizer \
       --fp16 \
       --micro-batch-size 1 \
       --seq-length 1024 \
       --out-seq-length 1024 \
       --temperature 1.0 \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --top_p 0.9 \
       --seed 42
...@@ -14,17 +14,6 @@ ...@@ -14,17 +14,6 @@
# limitations under the License. # limitations under the License.
import torch import torch
from .package_info import (
__description__,
__contact_names__,
__url__,
__download_url__,
__keywords__,
__license__,
__package_name__,
__version__,
)
from .global_vars import get_args from .global_vars import get_args
from .global_vars import get_current_global_batch_size from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches from .global_vars import get_num_microbatches
......
...@@ -81,6 +81,12 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -81,6 +81,12 @@ def parse_args(extra_args_provider=None, defaults={},
args.world_size, args.data_parallel_size, args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True) args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
# Deprecated arguments # Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \ assert args.batch_size is None, '--batch-size argument is no longer ' \
...@@ -614,6 +620,9 @@ def _add_distributed_args(parser): ...@@ -614,6 +620,9 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.') help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1, group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.') help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None, group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use ' help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.') '--tensor-model-parallel-size instead.')
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "ATen/ATen.h" #include "ATen/ATen.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh> #include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -329,6 +329,7 @@ void cuApplyLayerNorm( ...@@ -329,6 +329,7 @@ void cuApplyLayerNorm(
mean[i1] = mu; mean[i1] = mu;
invvar[i1] = c_invvar; invvar[i1] = c_invvar;
} }
__syncthreads();
} }
} }
......
...@@ -190,7 +190,8 @@ def _initialize_distributed(): ...@@ -190,7 +190,8 @@ def _initialize_distributed():
else: else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size, mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size, args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size) args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
def _init_autoresume(): def _init_autoresume():
......
...@@ -21,3 +21,4 @@ from .gpt_model import GPTModel ...@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
from .t5_model import T5Model from .t5_model import T5Model
from .language_model import get_language_model from .language_model import get_language_model
from .module import Float16Module from .module import Float16Module
from .enums import ModelType
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
import enum import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum): class LayerType(enum.Enum):
encoder = 1 encoder = 1
decoder = 2 decoder = 2
......
...@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None, encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False, scaled_init_method=None, add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal, decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True): pre_process=True, post_process=True):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
...@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler, ...@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
scaled_init_method, scaled_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder, add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type, decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler, add_pooler=add_pooler,
...@@ -159,6 +161,16 @@ class Embedding(MegatronModule): ...@@ -159,6 +161,16 @@ class Embedding(MegatronModule):
# Embeddings dropout # Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes): def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add """Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it. token-type embeddings in case the pretrained model does not have it.
...@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule): ...@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule):
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
add_encoder=True,
add_decoder=False, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal, decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False, add_pooler=False,
...@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule): ...@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule):
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings. # Embeddings.
if self.pre_process: if self.pre_process:
...@@ -302,25 +317,37 @@ class TransformerLanguageModel(MegatronModule): ...@@ -302,25 +317,37 @@ class TransformerLanguageModel(MegatronModule):
self._embedding_key = 'embedding' self._embedding_key = 'embedding'
# Transformer. # Transformer.
self.encoder = ParallelTransformer( # Encoder (usually set to True, False if part of an encoder-decoder
self.init_method, # architecture and in encoder-only stage).
output_layer_init_method, if self.add_encoder:
self_attn_mask_type=self.encoder_attn_mask_type, self.encoder = ParallelTransformer(
pre_process=self.pre_process, self.init_method,
post_process=self.post_process output_layer_init_method,
) self_attn_mask_type=self.encoder_attn_mask_type,
self._encoder_key = 'encoder' pre_process=self.pre_process,
post_process=self.post_process
# Decoder )
self._encoder_key = 'encoder'
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder: if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert args.pipeline_model_parallel_size == 1, \ assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder' 'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer( self.decoder = ParallelTransformer(
self.init_method, self.init_method,
output_layer_init_method, output_layer_init_method,
layer_type=LayerType.decoder, layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type) self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process)
self._decoder_key = 'decoder' self._decoder_key = 'decoder'
else:
self.decoder = None
if self.post_process: if self.post_process:
# Pooler. # Pooler.
...@@ -330,7 +357,31 @@ class TransformerLanguageModel(MegatronModule): ...@@ -330,7 +357,31 @@ class TransformerLanguageModel(MegatronModule):
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()""" """ See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
...@@ -339,20 +390,22 @@ class TransformerLanguageModel(MegatronModule): ...@@ -339,20 +390,22 @@ class TransformerLanguageModel(MegatronModule):
pooling_sequence_index=0, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False): enc_hidden_states=None, output_enc_hidden=False):
# Embeddings. # Encoder embedding.
if self.pre_process: if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids, encoder_input = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
encoder_input = embedding_output
else: else:
encoder_input = None encoder_input = None
# encoder. # Run encoder.
if enc_hidden_states is None: if enc_hidden_states is None:
encoder_output = self.encoder( if self.encoder is not None:
encoder_input, encoder_output = self.encoder(
enc_attn_mask, encoder_input,
inference_params=inference_params) enc_attn_mask,
inference_params=inference_params)
else:
encoder_output = self.encoder_hidden_state
else: else:
encoder_output = enc_hidden_states.to(encoder_input.dtype) encoder_output = enc_hidden_states.to(encoder_input.dtype)
...@@ -370,12 +423,16 @@ class TransformerLanguageModel(MegatronModule): ...@@ -370,12 +423,16 @@ class TransformerLanguageModel(MegatronModule):
else: else:
return encoder_output return encoder_output
# Decoder Embedding # Decoder embedding.
dec_embedding_output = self.embedding(dec_input_ids, if self.pre_process:
dec_position_ids) decoder_input = self.embedding(dec_input_ids,
# decoder dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder( decoder_output = self.decoder(
dec_embedding_output, decoder_input,
dec_attn_mask, dec_attn_mask,
encoder_output=encoder_output, encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask, enc_dec_attn_mask=enc_dec_attn_mask,
...@@ -395,9 +452,10 @@ class TransformerLanguageModel(MegatronModule): ...@@ -395,9 +452,10 @@ class TransformerLanguageModel(MegatronModule):
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._encoder_key] \ if self.add_encoder:
= self.encoder.state_dict_for_save_checkpoint( state_dict_[self._encoder_key] \
destination, prefix, keep_vars) = self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process: if self.post_process:
if self.add_pooler: if self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
...@@ -426,38 +484,39 @@ class TransformerLanguageModel(MegatronModule): ...@@ -426,38 +484,39 @@ class TransformerLanguageModel(MegatronModule):
self.embedding.load_state_dict(state_dict_, strict=strict) self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder. # Encoder.
if self._encoder_key in state_dict: if self.add_encoder:
state_dict_ = state_dict[self._encoder_key] if self._encoder_key in state_dict:
# for backward compatibility. state_dict_ = state_dict[self._encoder_key]
elif 'transformer' in state_dict: # For backward compatibility.
state_dict_ = state_dict['transformer'] elif 'transformer' in state_dict:
else: state_dict_ = state_dict['transformer']
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else: else:
state_dict_self_attention[key] = state_dict_[key] # For backward compatibility.
state_dict_ = state_dict_self_attention state_dict_ = {}
for key in state_dict.keys():
self.encoder.load_state_dict(state_dict_, strict=strict) if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.post_process: if self.post_process:
# pooler
if self.add_pooler: if self.add_pooler:
assert 'pooler' in state_dict, \ assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict) strict=strict)
# decoder # Decoder.
if self.add_decoder: if self.add_decoder:
assert 'decoder' in state_dict, \ assert 'decoder' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
......
...@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module): ...@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self): def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True): if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
mpu.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True): else:
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last ' raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false') 'stage, but share_word_embeddings is false')
return self.word_embeddings.weight return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal): def initialize_word_embeddings(self, init_method_normal):
...@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module): ...@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage # This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline # when we are using pipeline parallelism. Nothing to do if we aren't
# parallelism there is nothing to do. # using pipeline parallelism.
if args.pipeline_model_parallel_size == 1: if args.pipeline_model_parallel_size == 1:
return return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
# workers, so we do the following: # workers, so we do the following:
...@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module): ...@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data, torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group()) group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
else: else:
print("WARNING! Distributed processes aren't initialized, so " print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. " "word embeddings in the last layer are not initialized. "
......
...@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule): ...@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):
class T5Model(MegatronModule): class T5Model(MegatronModule):
"""T5 Language model.""" """T5 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
add_encoder=True,
add_decoder=True):
super(T5Model, self).__init__() super(T5Model, self).__init__()
args = get_args() args = get_args()
...@@ -95,19 +101,29 @@ class T5Model(MegatronModule): ...@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std, scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers) args.num_layers)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
add_decoder=True, add_encoder=add_encoder,
add_decoder=add_decoder,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.lm_head = T5LMHead( self.initialize_word_embeddings(init_method_normal)
self.language_model.embedding.word_embeddings.weight.size(0),
parallel_output) if self.post_process and self.add_decoder:
self._lm_head_key = 'lm_head' self.lm_head = T5LMHead(
self.word_embeddings_weight().size(0),
parallel_output)
self._lm_head_key = 'lm_head'
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()""" """See megatron.model.transformer.set_input_tensor()"""
...@@ -134,22 +150,28 @@ class T5Model(MegatronModule): ...@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
tokentype_ids=tokentype_ids, tokentype_ids=tokentype_ids,
enc_hidden_states=enc_hidden_states) enc_hidden_states=enc_hidden_states)
decoder_output, encoder_output = lm_output if self.post_process and self.add_decoder:
decoder_output, encoder_output = lm_output
# Output. # Output.
lm_logits = self.lm_head(decoder_output, lm_logits = self.lm_head(decoder_output,
self.language_model.embedding.word_embeddings.weight) self.word_embeddings_weight())
if lm_labels is None: if lm_labels is None:
return lm_logits, encoder_output return lm_logits
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else: else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), if self.fp16_lm_cross_entropy:
lm_labels) assert lm_logits.dtype == torch.half
return lm_loss, encoder_output lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss
elif self.add_decoder and not self.add_encoder:
decoder_output, encoder_output = lm_output
return decoder_output
else:
encoder_output = lm_output
return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -160,9 +182,14 @@ class T5Model(MegatronModule): ...@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \ if self.post_process and self.add_decoder:
= self.lm_head.state_dict_for_save_checkpoint( state_dict_[self._lm_head_key] \
destination, prefix, keep_vars) = self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
...@@ -170,5 +197,10 @@ class T5Model(MegatronModule): ...@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key], if self.post_process and self.add_decoder:
strict=strict) self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
# Load word embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
...@@ -21,7 +21,7 @@ import torch.nn.functional as F ...@@ -21,7 +21,7 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
...@@ -389,14 +389,18 @@ def get_bias_dropout_add(training): ...@@ -389,14 +389,18 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob): def bias_dropout_add_fused_train(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob): def bias_dropout_add_fused_inference(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
...@@ -564,9 +568,8 @@ class ParallelTransformer(MegatronModule): ...@@ -564,9 +568,8 @@ class ParallelTransformer(MegatronModule):
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers. # Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \ self.num_layers = mpu.get_num_layers(
'num_layers must be divisible by pipeline_model_parallel_size' args, args.model_type == ModelType.encoder_and_decoder)
self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
......
...@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group ...@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank from .initialize import get_pipeline_model_parallel_last_rank
......
...@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None ...@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly. # These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None ...@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None _MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source # A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage # rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
def is_unitialized(): def is_unitialized():
...@@ -52,13 +56,19 @@ def is_unitialized(): ...@@ -52,13 +56,19 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1, def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1, pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None): virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=None):
""" """
Initialize model data parallel groups. Initialize model data parallel groups.
Arguments: Arguments:
tensor_model_parallel_size: number of GPUs used to parallelize model tensor. tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the data-parallel groups. # Build the data-parallel groups.
...@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized' 'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \ assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized' 'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
...@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# first and last stages). # first and last stages).
if len(ranks) > 1: if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]] embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
else: else:
embedding_ranks = ranks embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks) group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks: if rank in embedding_ranks:
_EMBEDDING_GROUP = group _EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def model_parallel_is_initialized(): def model_parallel_is_initialized():
...@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank(): ...@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers must be divisible by number of ranks given to encoder'
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers must be divisible by number of ranks given to decoder'
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
else:
num_layers = args.num_layers
return num_layers
def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise.""" """Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual: if not ignore_virtual:
...@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False): ...@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
get_pipeline_model_parallel_world_size() - 1) get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank(): def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank.""" """Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
......
...@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): ...@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
grad_not_none = param.grad is not None grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param) is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
grad = param.grad.detach() if grad_not_none:
grad = param.grad.detach()
if grad_not_none: if grad_not_none:
# Make sure the grads are in fp32 # Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor' assert param.grad.type() == 'torch.cuda.FloatTensor'
......
...@@ -179,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -179,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
a `main_grad` field. If this is set, we are assuming a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad` that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens field instead of the typical `grad` field. This happens
for the DDP cases where there is a contihuous buffer for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32 to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad. and as a result we store those gradients in the main_grad.
...@@ -312,7 +312,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -312,7 +312,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
for model_group, main_group in zip(self.float16_groups, for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups): self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group): for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad: if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float() main_param.grad = model_param.main_grad.float()
else: else:
if model_param.grad is not None: if model_param.grad is not None:
......