Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, tensor-model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, a SQuAD 2.0 F1-score of 91.7, and a RACE accuracy of 90.9%.
...@@ -218,7 +218,7 @@ These scripts use the PyTorch distributed launcher for distributed training. As
The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the back-propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with the back-propagation computation. To switch between these two options, use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameter model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory, and for some configurations (e.g., 2.5 billion parameters using 2-way model parallelism and 1.2 billion parameters with no model parallelism) it can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.
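For readers unfamiliar with the `local` option, the idea is simply to all-reduce every gradient over the data-parallel group once the backward pass has finished. A minimal, hedged sketch of that pattern in plain PyTorch (the helper below is illustrative, not the repository's `--DDP-impl local` implementation):
<pre>
import torch
import torch.distributed as dist

def allreduce_gradients(model, data_parallel_group, data_parallel_size):
    # Sketch of a "local" DDP step: after loss.backward() has completed,
    # average each parameter's gradient across the data-parallel ranks.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, group=data_parallel_group)
            param.grad.data /= data_parallel_size
</pre>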
Second, we developed a simple and efficient tensor model parallel approach. To use model parallelism, add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--tensor-model-parallel-size` is 1, which will not implement model parallelism.
Other than these minor changes, the distributed training is identical to the training on a single GPU.
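To make the GPU split concrete, a small sketch of the arithmetic described above (illustrative numbers only):
<pre>
# Sketch: how GPUs are divided between tensor model parallelism and data parallelism.
WORLD_SIZE = 16   # total number of GPUs passed to the distributed launcher
MP_SIZE = 4       # value given to --tensor-model-parallel-size

data_parallel_size = WORLD_SIZE // MP_SIZE
print(data_parallel_size)  # 4 data-parallel replicas, each sharded over 4 GPUs
</pre>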
...@@ -245,7 +245,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
...@@ -269,7 +269,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
...@@ -362,14 +362,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--load $CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 5 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 3 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...
...@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
...
#!/bin/bash
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
...
...@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...
...@@ -17,7 +17,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...
...@@ -54,14 +54,14 @@ def parse_args(extra_args_provider=None, defaults={},
    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    if args.rank == 0:
        print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
            args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
    # Fp16 loss scaling.
    args.dynamic_loss_scale = False
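A brief sketch of what the clamping above works out to for one concrete configuration (assumed example values, standalone code):
<pre>
# Example: world_size = 16 GPUs with 2-way tensor and 2-way pipeline parallelism.
world_size = 16
tensor_model_parallel_size = min(2, world_size)                    # -> 2
pipeline_model_parallel_size = min(2, world_size //
                                   tensor_model_parallel_size)     # -> 2

# The leftover factor is used for data parallelism.
data_parallel_size = world_size // (tensor_model_parallel_size *
                                    pipeline_model_parallel_size)  # -> 4
print(tensor_model_parallel_size, pipeline_model_parallel_size, data_parallel_size)
</pre>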
...@@ -364,12 +364,12 @@ def _add_mixed_precision_args(parser):
def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')
    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--use-pipelining', action='store_true',
                       help='Use pipelining to increase throughput of pipeline model parallelism')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
...
...@@ -59,7 +59,7 @@ def check_checkpoint_args(checkpoint_args):
    _compare('make_vocab_size_divisible_by')
    _compare('padded_vocab_size')
    _compare('tokenizer_type')
    _compare('tensor_model_parallel_size')
def ensure_directory_exists(filename):
...@@ -76,16 +76,16 @@ def get_checkpoint_name(checkpoints_path, iteration,
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                mpu.get_tensor_model_parallel_rank()),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank()),
                        'model_optim_rng.pt')
...
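For intuition, here is a small, self-contained sketch of the directory naming above (assumed example ranks and paths, not output from the repository):
<pre>
import os

checkpoints_path = 'checkpoints/gpt2_345m'   # assumed example path
iteration = 5000
tensor_rank, pipeline_rank, pipeline_world_size = 1, 3, 4

directory = 'iter_{:07d}'.format(iteration)
if pipeline_world_size == 1:
    name = os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(tensor_rank),
                        'model_optim_rng.pt')
else:
    name = os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(tensor_rank, pipeline_rank),
                        'model_optim_rng.pt')
print(name)  # checkpoints/gpt2_345m/iter_0005000/mp_rank_01_003/model_optim_rng.pt
</pre>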
...@@ -153,10 +153,10 @@ def get_samples_mapping_(indexed_dataset,
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
...
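The assertion above rests on a counting identity: all-reducing a 1 over the data-parallel group and then over the pipeline group yields data-parallel size times pipeline size, which equals the world size divided by the tensor-parallel size. A tiny sketch with assumed sizes:
<pre>
# Why the assert holds: the three group sizes factor the world size.
tensor_size, pipeline_size, data_parallel_size = 2, 2, 4
world_size = tensor_size * pipeline_size * data_parallel_size   # 16

# Summing a 1 across the data-parallel group, then across the pipeline group,
# multiplies the count by the size of each group.
count = 1 * data_parallel_size * pipeline_size
assert count == world_size // tensor_size
</pre>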
...@@ -204,10 +204,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
    # Load mappings.
    start_time = time.time()
...
...@@ -112,7 +112,7 @@ def main():
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
...
...@@ -74,10 +74,10 @@ class FP16_Module(MegatronModule):
    def forward(self, *inputs, **kwargs):
        convert_inputs = True
        convert_outputs = True
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            if not mpu.is_pipeline_first_stage():
                convert_inputs = False
            if not mpu.is_pipeline_last_stage():
                convert_outputs = False
        if convert_inputs:
            inputs = fp32_to_fp16(inputs)
...@@ -227,7 +227,7 @@ class FP16_Optimizer(object):
                    master_param = param.detach().clone().float()
                    master_param.requires_grad = True
                    # Copy the model parallel flag.
                    master_param.tensor_model_parallel = param.tensor_model_parallel
                    param_group['params'][i] = master_param
                    fp32_from_fp16_params_this_group.append(master_param)
                    # Reset existing state dict key to the new master param.
...
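For context, the lines above follow the standard mixed-precision recipe: keep an fp32 "master" copy of every fp16 parameter so the optimizer step happens in full precision. A hedged, standalone sketch of that idea in plain PyTorch (toy tensors, not the repository's FP16_Optimizer):
<pre>
import torch

# Toy fp16 parameter and its fp32 master copy.
fp16_param = torch.nn.Parameter(torch.randn(4, 4).half())

master_param = fp16_param.detach().clone().float()   # optimizer updates this copy
master_param.requires_grad = True

# After the optimizer steps on master_param, the fp16 weights are refreshed from it.
with torch.no_grad():
    fp16_param.copy_(master_param)   # copy_ casts fp32 back down to fp16
</pre>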
...@@ -26,7 +26,7 @@ from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
...@@ -65,9 +65,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        args.use_cpu_initialization=True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
...@@ -121,14 +121,14 @@ def _initialize_distributed():
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)
    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size)
def _init_autoresume():
...@@ -143,13 +143,13 @@ def _init_autoresume():
def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + mpu.get_pipeline_model_parallel_rank()
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
...
...@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
        args = get_args()
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.tensor_model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output
...@@ -150,8 +150,8 @@ class BertModelBase(MegatronModule):
            init_method=init_method,
            scaled_init_method=scaled_init_method)
        if mpu.is_pipeline_last_stage():
            if not mpu.is_pipeline_first_stage():
                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                # If first and last stages are different, set word_embeddings
                # weights to 0 here, then copy first stage's weights using all_reduce
...@@ -172,14 +172,14 @@ class BertModelBase(MegatronModule):
        self._binary_head_key = 'binary_head'
        # Ensure that first and last stages have the same initial embedding weights.
        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())
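The comments above describe how the word-embedding weights are kept tied across pipeline stages: the last stage zero-initializes its copy, and a sum all-reduce over the embedding group (first and last stage only) then leaves both stages holding the first stage's values. A toy, torch-only sketch of why that works (illustrative tensors, no process groups):
<pre>
import torch

first_stage_weights = torch.randn(8, 4)   # first stage owns the real initial values
last_stage_weights = torch.zeros(8, 4)    # last stage starts its copy at zero

# A sum all-reduce over just these two stages is equivalent to adding the tensors;
# the zeros contribute nothing, so both stages end up with the first stage's weights.
summed = first_stage_weights + last_stage_weights
assert torch.equal(summed, first_stage_weights)
</pre>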
    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage():
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage():
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')
...@@ -190,7 +190,7 @@ class BertModelBase(MegatronModule):
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = bert_model_input
            position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
...@@ -198,12 +198,12 @@ class BertModelBase(MegatronModule):
        else:
            args = [bert_model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None
        if mpu.is_pipeline_last_stage():
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
...@@ -222,15 +222,15 @@ class BertModelBase(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
...@@ -240,14 +240,14 @@ class BertModelBase(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
...
...@@ -80,8 +80,8 @@ class GPT2ModelBase(MegatronModule):
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))
        if mpu.is_pipeline_last_stage():
            if not mpu.is_pipeline_first_stage():
                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                # If first and last stages are different, set word_embeddings
                # weights to 0 here, then copy first stage's weights using all_reduce
...@@ -92,14 +92,14 @@ class GPT2ModelBase(MegatronModule):
                self.word_embeddings.weight.data.fill_(0)
        # Ensure that first and last stages have the same initial embedding weights.
        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())
    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage():
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage():
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')
...@@ -109,7 +109,7 @@ class GPT2ModelBase(MegatronModule):
                forward_method_parallel_output=None):
        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = gpt2_model_input
            args = [input_ids, position_ids, attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
...@@ -117,7 +117,7 @@ class GPT2ModelBase(MegatronModule):
            args = [gpt2_model_input, attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            return post_language_model_processing(
                lm_output, labels,
                self.word_embeddings_weight(),
...@@ -136,7 +136,7 @@ class GPT2ModelBase(MegatronModule):
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
...@@ -145,7 +145,7 @@ class GPT2ModelBase(MegatronModule):
        """Customized load."""
        # Load word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
...
...@@ -29,7 +29,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
...@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
    if parallel_output:
        return logits_parallel
    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
...@@ -57,14 +57,14 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
    args = [attention_mask_func, init_method, scaled_init_method]
    kwargs = {}
    cls = None
    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModel
        kwargs['num_tokentypes'] = num_tokentypes
        kwargs['add_pooler'] = add_pooler
    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelFirstStage
        kwargs['num_tokentypes'] = num_tokentypes
    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelLastStage
        kwargs['add_pooler'] = add_pooler
    else:
...@@ -291,7 +291,7 @@ class TransformerLanguageModelBase(MegatronModule):
        self.add_pooler = add_pooler
        # Embeddings.
        if mpu.is_pipeline_first_stage():
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
...@@ -307,7 +307,7 @@ class TransformerLanguageModelBase(MegatronModule):
        self._transformer_key = 'transformer'
        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'
...@@ -316,7 +316,7 @@ class TransformerLanguageModelBase(MegatronModule):
                pooling_sequence_index=0):
        # Embeddings.
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = language_model_input
            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
...@@ -330,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
                                                  layer_past=layer_past,
                                                  get_key_value=get_key_value)
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            pooled_output = self.pooler(transformer_output,
                                        pooling_sequence_index)
            return transformer_output, pooled_output
...@@ -342,14 +342,14 @@ class TransformerLanguageModelBase(MegatronModule):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._transformer_key] \ state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint( = self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_inter_layer_last_stage() and self.add_pooler: if mpu.is_pipeline_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint( = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
...@@ -360,7 +360,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -360,7 +360,7 @@ class TransformerLanguageModelBase(MegatronModule):
"""Customized load.""" """Customized load."""
# Embedding. # Embedding.
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
if self._embedding_key in state_dict: if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key] state_dict_ = state_dict[self._embedding_key]
else: else:
...@@ -383,7 +383,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -383,7 +383,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.transformer.load_state_dict(state_dict_, strict=strict) self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler. # Pooler.
if mpu.is_inter_layer_last_stage() and self.add_pooler: if mpu.is_pipeline_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \ assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.pooler.load_state_dict(state_dict[self._pooler_key],
......
...@@ -19,7 +19,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert args.tensor_model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"
    print_rank_0('building ICTBertModel...')
...
...@@ -130,7 +130,7 @@ class ParallelSelfAttention(MegatronModule):
        self.layer_number = max(1, layer_number)
        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
...@@ -505,12 +505,12 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_num_layers = args.checkpoint_num_layers
        # Number of layers.
        self.num_layers = args.num_layers // args.pipeline_model_parallel_size
        # TODO: Need to do something different in case self.num_layers != self.num_unique_layers?
        if args.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        else:
            self.num_unique_layers = args.num_unique_layers // args.pipeline_model_parallel_size
        assert self.num_layers == self.num_unique_layers, \
            'number of layers should be equal to the number of unique layers'
        self.param_sharing_style = args.param_sharing_style
...@@ -520,7 +520,7 @@ class ParallelTransformer(MegatronModule):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_unique_layers)])
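To make the layer partitioning concrete, a short sketch of the numbering produced by the offset above (assumed example: 24 layers over 4 pipeline stages):
<pre>
num_total_layers = 24
pipeline_model_parallel_size = 4
num_layers = num_total_layers // pipeline_model_parallel_size   # 6 layers per stage

for pipeline_rank in range(pipeline_model_parallel_size):
    offset = pipeline_rank * num_layers
    print(pipeline_rank, [i + 1 + offset for i in range(num_layers)])
# 0 [1, 2, 3, 4, 5, 6]
# 1 [7, 8, 9, 10, 11, 12]
# 2 [13, 14, 15, 16, 17, 18]
# 3 [19, 20, 21, 22, 23, 24]
</pre>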
...@@ -533,7 +533,7 @@ class ParallelTransformer(MegatronModule):
            '{:3d}'.format(i, self._get_layer_index(i)),
            flush=True)
        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
...@@ -610,7 +610,7 @@ class ParallelTransformer(MegatronModule):
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
...