Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, intra-layer-model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, tensor-model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, a SQuAD 2.0 F1-score of 91.7, and a RACE accuracy of 90.9%.
......@@ -218,7 +218,7 @@ These scripts use the PyTorch distributed launcher for distributed training. As
The two tiers of parallelism are data and model parallelism. First, we provide two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the back-propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with the back-propagation computation. To switch between these two options, use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameter model running on 512 GPUs, scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory, and for some configurations (e.g., 2.5 billion parameters using 2-way model parallelism and 1.2 billion parameters with no model parallelism) it can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.
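As a rough illustration of the difference (a minimal sketch with hypothetical helper names, not the repository's actual `--DDP-impl` code), the local implementation all-reduces every gradient in one pass after the backward pass has finished, while the overlapping approach registers hooks so each gradient is reduced as soon as it is produced:
<pre>
import torch.distributed as dist

def allreduce_after_backward(model):
    # "--DDP-impl local" style (sketch): call loss.backward() first, then
    # average all gradients across the data-parallel group in a single pass.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data)
            param.grad.data /= world_size

def overlap_allreduce_with_backward(model):
    # "--DDP-impl torch" style (sketch): a hook fires as soon as each
    # gradient is computed, so communication overlaps with the remaining
    # backward computation (torch.nn.parallel.DistributedDataParallel
    # does this internally, with gradient bucketing).
    world_size = dist.get_world_size()

    def hook(grad):
        dist.all_reduce(grad)
        return grad / world_size

    for param in model.parameters():
        if param.requires_grad:
            param.register_hook(hook)
</pre>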
Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--intra-layer-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--intra-layer-model-parallel-size` is 1, which will not implement model parallelism.
Second, we developed a simple and efficient tensor model parallel approach. To use model parallelism, add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--tensor-model-parallel-size` is 1, which will not implement model parallelism.
Other than these minor changes, the distributed training is identical to the training on a single GPU.
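For instance, the split between model and data parallelism follows directly from the two sizes; the numbers below are hypothetical and only illustrate the arithmetic:
<pre>
WORLD_SIZE = 16        # total GPUs given to the distributed launcher
MP_SIZE = 4            # --tensor-model-parallel-size

# Each model replica spans MP_SIZE GPUs; the remaining factor of the
# world size is used for data parallelism.
assert WORLD_SIZE % MP_SIZE == 0
data_parallel_size = WORLD_SIZE // MP_SIZE
print(data_parallel_size)   # 4 data-parallel replicas, each split over 4 GPUs
</pre>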
......@@ -245,7 +245,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--intra-layer-model-parallel-size $MP_SIZE \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
......@@ -269,7 +269,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--intra-layer-model-parallel-size $MP_SIZE \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
......@@ -362,14 +362,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre>
INTRA_LAYER_MODEL_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
......
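Conceptually, the merge reassembles each tensor-parallel weight by concatenating its partitions along the dimension it was split on; the sketch below uses made-up shapes and a hypothetical helper, and is not the logic of `tools/merge_mp_partitions.py` itself:
<pre>
import torch

def merge_partitions(partitions, partition_dim):
    # Column-parallel weights are split along the output dimension (dim 0 of
    # the weight matrix) and row-parallel weights along the input dimension
    # (dim 1); merging is concatenation along that same dimension.
    return torch.cat(partitions, dim=partition_dim)

# e.g. two column-parallel shards of a 1024 x 4096 weight
shards = [torch.randn(512, 4096), torch.randn(512, 4096)]
full_weight = merge_partitions(shards, partition_dim=0)   # shape (1024, 4096)
</pre>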
......@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--load $CHECKPOINT \
--intra-layer-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 5 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--intra-layer-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 3 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--intra-layer-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
--intra-layer-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
......
#!/bin/bash
INTRA_LAYER_MODEL_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
......
......@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--intra-layer-model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -17,7 +17,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
......@@ -54,14 +54,14 @@ def parse_args(extra_args_provider=None, defaults={},
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
args.intra_layer_model_parallel_size = min(
args.intra_layer_model_parallel_size, args.world_size)
args.inter_layer_model_parallel_size = min(
args.inter_layer_model_parallel_size,
(args.world_size // args.intra_layer_model_parallel_size))
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
if args.rank == 0:
print('using world size: {} and intra-layer-model-parallel size: {} '.format(
args.world_size, args.intra_layer_model_parallel_size))
print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
# Fp16 loss scaling.
args.dynamic_loss_scale = False
......@@ -364,12 +364,12 @@ def _add_mixed_precision_args(parser):
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--intra-layer-model-parallel-size', type=int, default=1,
help='Degree of intra-layer model parallelism.')
group.add_argument('--inter-layer-model-parallel-size', type=int, default=1,
help='Degree of inter-layer model parallelism.')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--use-pipelining', action='store_true',
help='Use pipelining to increase throughput of inter-layer model parallelism')
help='Use pipelining to increase throughput of pipeline model parallelism')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
......
......@@ -59,7 +59,7 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('intra_layer_model_parallel_size')
_compare('tensor_model_parallel_size')
def ensure_directory_exists(filename):
......@@ -76,16 +76,16 @@ def get_checkpoint_name(checkpoints_path, iteration,
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the intra-layer and inter-layer MP rank.
if mpu.get_inter_layer_model_parallel_world_size() == 1:
# Use both the tensor and pipeline MP rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_intra_layer_model_parallel_rank()),
mpu.get_tensor_model_parallel_rank()),
'model_optim_rng.pt')
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_intra_layer_model_parallel_rank(),
mpu.get_inter_layer_model_parallel_rank()),
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()),
'model_optim_rng.pt')
......
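As a concrete illustration of the naming scheme produced by the code above (hypothetical rank and iteration values):
<pre>
iteration, tensor_rank, pipeline_rank = 5000, 1, 2

# Pipeline parallelism disabled: one directory per tensor-parallel rank.
print('iter_{:07d}/mp_rank_{:02d}/model_optim_rng.pt'.format(iteration, tensor_rank))
# -> iter_0005000/mp_rank_01/model_optim_rng.pt

# Pipeline parallelism enabled: the pipeline-parallel rank is appended.
print('iter_{:07d}/mp_rank_{:02d}_{:03d}/model_optim_rng.pt'.format(
    iteration, tensor_rank, pipeline_rank))
# -> iter_0005000/mp_rank_01_002/model_optim_rng.pt
</pre>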
......@@ -153,10 +153,10 @@ def get_samples_mapping_(indexed_dataset,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
......
......@@ -204,10 +204,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load mappings.
start_time = time.time()
......
......@@ -112,7 +112,7 @@ def main():
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.intra_layer_model_parallel_size = 1
args.tensor_model_parallel_size = 1
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
......
......@@ -74,10 +74,10 @@ class FP16_Module(MegatronModule):
def forward(self, *inputs, **kwargs):
convert_inputs = True
convert_outputs = True
if mpu.get_inter_layer_model_parallel_world_size() > 1:
if not mpu.is_inter_layer_first_stage():
if mpu.get_pipeline_model_parallel_world_size() > 1:
if not mpu.is_pipeline_first_stage():
convert_inputs = False
if not mpu.is_inter_layer_last_stage():
if not mpu.is_pipeline_last_stage():
convert_outputs = False
if convert_inputs:
inputs = fp32_to_fp16(inputs)
......@@ -227,7 +227,7 @@ class FP16_Optimizer(object):
master_param = param.detach().clone().float()
master_param.requires_grad = True
# Copy the model parallel flag.
master_param.intra_layer_model_parallel = param.intra_layer_model_parallel
master_param.tensor_model_parallel = param.tensor_model_parallel
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
......
......@@ -26,7 +26,7 @@ from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import set_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_world_size
from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False):
......@@ -65,9 +65,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args.use_cpu_initialization=True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager to call when it has DDP initialized
set_intra_layer_model_parallel_rank(args.rank)
set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
# Megatron's MPU is the master. Complete initialization right away.
......@@ -121,14 +121,14 @@ def _initialize_distributed():
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the intra-layer model-parallel, inter-layer model-parallel, and
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
if device_count > 0:
if mpu.model_parallel_is_initialized():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.intra_layer_model_parallel_size,
args.inter_layer_model_parallel_size)
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
def _init_autoresume():
......@@ -143,13 +143,13 @@ def _init_autoresume():
def _set_random_seed(seed_):
"""Set random seed for reproducability."""
if seed_ is not None and seed_ > 0:
# Ensure that different inter-layer MP stages get different seeds.
seed = seed_ + mpu.get_inter_layer_model_parallel_rank()
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + mpu.get_pipeline_model_parallel_rank()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.intra_layer_model_parallel_cuda_manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
......@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.intra_layer_model_parallel = True
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output
......@@ -150,8 +150,8 @@ class BertModelBase(MegatronModule):
init_method=init_method,
scaled_init_method=scaled_init_method)
if mpu.is_inter_layer_last_stage():
if not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first stage's weights using all_reduce
......@@ -172,14 +172,14 @@ class BertModelBase(MegatronModule):
self._binary_head_key = 'binary_head'
# Ensure that first and last stages have the same initial embedding weights.
if mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
def word_embeddings_weight(self):
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
......@@ -190,7 +190,7 @@ class BertModelBase(MegatronModule):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
kwargs = {}
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
......@@ -198,12 +198,12 @@ class BertModelBase(MegatronModule):
else:
args = [bert_model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_inter_layer_last_stage() and self.add_binary_head:
if mpu.is_pipeline_last_stage() and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
......@@ -222,15 +222,15 @@ class BertModelBase(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_inter_layer_last_stage() and self.add_binary_head:
if mpu.is_pipeline_last_stage() and self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings.
if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
......@@ -240,14 +240,14 @@ class BertModelBase(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if mpu.is_inter_layer_last_stage() and self.add_binary_head:
if mpu.is_pipeline_last_stage() and self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
......
......@@ -80,8 +80,8 @@ class GPT2ModelBase(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
if mpu.is_inter_layer_last_stage():
if not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first stage's weights using all_reduce
......@@ -92,14 +92,14 @@ class GPT2ModelBase(MegatronModule):
self.word_embeddings.weight.data.fill_(0)
# Ensure that first and last stages have the same initial embedding weights.
if mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
def word_embeddings_weight(self):
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
......@@ -109,7 +109,7 @@ class GPT2ModelBase(MegatronModule):
forward_method_parallel_output=None):
kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = gpt2_model_input
args = [input_ids, position_ids, attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
......@@ -117,7 +117,7 @@ class GPT2ModelBase(MegatronModule):
args = [gpt2_model_input, attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
return post_language_model_processing(
lm_output, labels,
self.word_embeddings_weight(),
......@@ -136,7 +136,7 @@ class GPT2ModelBase(MegatronModule):
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
# Save word_embeddings.
if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
......@@ -145,7 +145,7 @@ class GPT2ModelBase(MegatronModule):
"""Customized load."""
# Load word_embeddings.
if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
......
......@@ -29,7 +29,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_intra_layer_model_parallel_region(input_)
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
......@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if parallel_output:
return logits_parallel
return mpu.gather_from_intra_layer_model_parallel_region(logits_parallel)
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
......@@ -57,14 +57,14 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
args = [attention_mask_func, init_method, scaled_init_method]
kwargs = {}
cls = None
if mpu.is_inter_layer_first_stage() and mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_pooler'] = add_pooler
elif mpu.is_inter_layer_first_stage() and not mpu.is_inter_layer_last_stage():
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
kwargs['num_tokentypes'] = num_tokentypes
elif not mpu.is_inter_layer_first_stage() and mpu.is_inter_layer_last_stage():
elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelLastStage
kwargs['add_pooler'] = add_pooler
else:
......@@ -291,7 +291,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.add_pooler = add_pooler
# Embeddings.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
......@@ -307,7 +307,7 @@ class TransformerLanguageModelBase(MegatronModule):
self._transformer_key = 'transformer'
# Pooler.
if mpu.is_inter_layer_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage() and self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
......@@ -316,7 +316,7 @@ class TransformerLanguageModelBase(MegatronModule):
pooling_sequence_index=0):
# Embeddings.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
......@@ -330,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
layer_past=layer_past,
get_key_value=get_key_value)
if mpu.is_inter_layer_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
......@@ -342,14 +342,14 @@ class TransformerLanguageModelBase(MegatronModule):
"""For easy load."""
state_dict_ = {}
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_inter_layer_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
......@@ -360,7 +360,7 @@ class TransformerLanguageModelBase(MegatronModule):
"""Customized load."""
# Embedding.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
......@@ -383,7 +383,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if mpu.is_inter_layer_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
......
......@@ -19,7 +19,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.intra_layer_model_parallel_size == 1, \
assert args.tensor_model_parallel_size == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
......
......@@ -130,7 +130,7 @@ class ParallelSelfAttention(MegatronModule):
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_intra_layer_model_parallel_world_size()
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
......@@ -505,12 +505,12 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers.
self.num_layers = args.num_layers // args.inter_layer_model_parallel_size
self.num_layers = args.num_layers // args.pipeline_model_parallel_size
# TODO: Need to do something different in case self.num_layers != self.num_unique_layers?
if args.num_unique_layers is None:
self.num_unique_layers = self.num_layers
else:
self.num_unique_layers = args.num_unique_layers // args.inter_layer_model_parallel_size
self.num_unique_layers = args.num_unique_layers // args.pipeline_model_parallel_size
assert self.num_layers == self.num_unique_layers, \
'number of layers should be equal to the number of unique layers'
self.param_sharing_style = args.param_sharing_style
......@@ -520,7 +520,7 @@ class ParallelTransformer(MegatronModule):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
offset = mpu.get_inter_layer_model_parallel_rank() * self.num_layers
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_unique_layers)])
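As a worked example of the offset arithmetic above (hypothetical sizes): with 24 layers and a pipeline-model-parallel size of 4, each stage builds 6 layers, and the stage with pipeline rank 2 owns layers 13 through 18:
<pre>
total_layers = 24                 # args.num_layers
pipeline_size = 4                 # args.pipeline_model_parallel_size
pipeline_rank = 2                 # this stage's pipeline-parallel rank

num_layers = total_layers // pipeline_size            # 6 layers per stage
offset = pipeline_rank * num_layers                   # 12
print([i + 1 + offset for i in range(num_layers)])    # [13, 14, 15, 16, 17, 18]
</pre>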
......@@ -533,7 +533,7 @@ class ParallelTransformer(MegatronModule):
'{:3d}'.format(i, self._get_layer_index(i)),
flush=True)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
......@@ -610,7 +610,7 @@ class ParallelTransformer(MegatronModule):
hidden_states = hidden_states.transpose(0, 1).contiguous()
# Final layer norm.
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
......