Commit 7abd3e90 authored by Deepak Narayanan

Pipeline parallelism implementation with periodic full-pipeline syncs

Also includes the following changes for the inter-layer model-parallel implementation:
- Refactoring of model implementations
- Training loop changes to support inter-layer communication using `ring_exchange`
- New groups for inter-layer communication
- Checkpoint changes
- Command line arguments
parent 28cd66e1
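The training-loop changes mentioned above hand activations from one inter-layer (pipeline) stage to the next; that loop is not shown in the diff below. The following is only a minimal sketch of the hand-off pattern, using `torch.distributed.send`/`recv` as a stand-in for the `ring_exchange` primitive named in the commit message; the helper names, tensor shape, and rank bookkeeping are illustrative assumptions rather than the commit's actual code.

<pre>
import torch
import torch.distributed as dist


def send_forward(output_tensor, stage_rank, num_stages):
    """Hand this stage's activations to the next pipeline stage (no-op on the last stage)."""
    if stage_rank < num_stages - 1:
        dist.send(output_tensor.contiguous(), dst=stage_rank + 1)


def recv_forward(shape, stage_rank, device, dtype=torch.half):
    """Receive activations from the previous pipeline stage (no-op on the first stage)."""
    if stage_rank == 0:
        return None
    input_tensor = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(input_tensor, src=stage_rank - 1)
    return input_tensor
</pre>

The backward pass mirrors this exchange in the opposite direction, and the periodic full-pipeline syncs from the commit title sit on top of it; neither is shown here.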
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, intra-layer-model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.

Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, a SQuAD 2.0 F1-score of 91.7, and a RACE accuracy of 90.9%.
...@@ -218,7 +218,7 @@ These scripts use the PyTorch distributed launcher for distributed training. As
The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.

Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--intra-layer-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--intra-layer-model-parallel-size` is 1, which will not implement model parallelism.

Other than these minor changes, the distributed training is identical to the training on a single GPU.
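As a concrete illustration of how the GPUs are partitioned (the sizes below are made up; the inter-layer dimension is the one added by this commit):

<pre>
# Illustrative only: how the world of GPUs splits across the parallelism dimensions.
WORLD_SIZE = 16
INTRA_LAYER_MP_SIZE = 2   # --intra-layer-model-parallel-size
INTER_LAYER_MP_SIZE = 4   # --inter-layer-model-parallel-size (added by this commit)

data_parallel_size = WORLD_SIZE // (INTRA_LAYER_MP_SIZE * INTER_LAYER_MP_SIZE)
print(data_parallel_size)  # 2 data-parallel replicas of the 2 x 4 model-parallel grid
</pre>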
...@@ -245,7 +245,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --intra-layer-model-parallel-size $MP_SIZE \
       --DDP-impl torch
</pre>
...@@ -269,7 +269,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --intra-layer-model-parallel-size $MP_SIZE \
       --DDP-impl torch
</pre>
...@@ -362,14 +362,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre>
INTRA_LAYER_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
        --model-type BERT \
        --intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
        --tokenizer-type BertWordPieceLowerCase \
        --vocab-file $VOCAB_FILE \
        --num-layers 24 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
       --vocab-file $VOCAB_FILE \
       --merge-file $MERGE_FILE \
       --load $CHECKPOINT \
       --intra-layer-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
       --vocab-file $VOCAB_FILE \
       --epochs 5 \
       --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
       --intra-layer-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
...
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
       --vocab-file $VOCAB_FILE \
       --epochs 3 \
       --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
       --intra-layer-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
...
...@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
       --intra-layer-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --load $CHECKPOINT_PATH \
...
#!/bin/bash
INTRA_LAYER_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
        --model-type BERT \
        --intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
        --tokenizer-type BertWordPieceLowerCase \
        --vocab-file $VOCAB_FILE \
        --num-layers 24 \
...
...@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_bert.py \
       --intra-layer-model-parallel-size 1 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
...
...@@ -33,9 +33,18 @@ from .global_vars import get_timers
from .initialize import initialize_megatron


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == (
                torch.distributed.get_world_size() - 1):
            print(message, flush=True)
    else:
        print(message, flush=True)
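# Note: with inter-layer model parallelism the loss is only computed on the last
# pipeline stage, so a training loop would typically log it with print_rank_last
# rather than print_rank_0, e.g.
#     print_rank_last(' iteration {} | lm loss {:.6E}'.format(iteration, lm_loss))
# (the log format here is illustrative, not taken from this commit).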
...@@ -54,10 +54,14 @@ def parse_args(extra_args_provider=None, defaults={},
    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    args.intra_layer_model_parallel_size = min(
        args.intra_layer_model_parallel_size, args.world_size)
    args.inter_layer_model_parallel_size = min(
        args.inter_layer_model_parallel_size,
        (args.world_size // args.intra_layer_model_parallel_size))
    if args.rank == 0:
        print('using world size: {} and intra-layer-model-parallel size: {} '.format(
            args.world_size, args.intra_layer_model_parallel_size))

    # Fp16 loss scaling.
    args.dynamic_loss_scale = False
...@@ -192,7 +196,7 @@ def _add_regularization_args(parser):
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
...@@ -358,10 +362,14 @@ def _add_mixed_precision_args(parser):
def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--intra-layer-model-parallel-size', type=int, default=1,
                       help='Degree of intra-layer model parallelism.')
    group.add_argument('--inter-layer-model-parallel-size', type=int, default=1,
                       help='Degree of inter-layer model parallelism.')
    group.add_argument('--use-pipelining', action='store_true',
                       help='Use pipelining to increase throughput of inter-layer model parallelism')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
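    # Note: without --use-pipelining the inter-layer stages of a batch execute one
    # after another, so at most one stage is busy at a time; --use-pipelining keeps
    # several stages in flight at once to raise throughput. This describes the flag's
    # intent only; the scheduling itself lives in the training loop, not in this hunk.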
...
...@@ -59,7 +59,7 @@ def check_checkpoint_args(checkpoint_args):
    _compare('make_vocab_size_divisible_by')
    _compare('padded_vocab_size')
    _compare('tokenizer_type')
    _compare('intra_layer_model_parallel_size')


def ensure_directory_exists(filename):
...@@ -70,16 +70,22 @@ def ensure_directory_exists(filename):
def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the intra-layer and inter-layer MP rank.
    if mpu.get_inter_layer_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                mpu.get_intra_layer_model_parallel_rank()),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            mpu.get_intra_layer_model_parallel_rank(),
                            mpu.get_inter_layer_model_parallel_rank()),
                        'model_optim_rng.pt')
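# For example (paths are illustrative): with intra-layer rank 1, inter-layer rank 3 and
# iteration 1000, a run saved under /ckpt would write
#     /ckpt/iter_0001000/mp_rank_01_003/model_optim_rng.pt
# while a run without inter-layer parallelism keeps the old single-index layout,
#     /ckpt/iter_0001000/mp_rank_01/model_optim_rng.pt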
...
...@@ -153,8 +153,10 @@ def get_samples_mapping_(indexed_dataset,
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))
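    # The assertion holds because the two all-reduces each add one count per rank in
    # their group: counts becomes data_parallel_size * inter_layer_size, which equals
    # world_size // intra_layer_model_parallel_size when the three group sizes multiply
    # out to the full world size.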
    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
...
...@@ -204,8 +204,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))

    # Load mappings.
    start_time = time.time()
...
...@@ -112,7 +112,7 @@ def main():
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.intra_layer_model_parallel_size = 1

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
...
...@@ -26,6 +26,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.module import MegatronModule
from megatron import mpu

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
...@@ -71,7 +72,19 @@ class FP16_Module(MegatronModule):
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        convert_inputs = True
        convert_outputs = True
        if mpu.get_inter_layer_model_parallel_world_size() > 1:
            if not mpu.is_inter_layer_first_stage():
                convert_inputs = False
            if not mpu.is_inter_layer_last_stage():
                convert_outputs = False
        if convert_inputs:
            inputs = fp32_to_fp16(inputs)
        outputs = self.module(*inputs, **kwargs)
        if convert_outputs:
            outputs = fp16_to_fp32(outputs)
        return outputs
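    # Note: with inter-layer model parallelism only the first stage takes its inputs
    # directly from the data iterator, so only it applies the fp32 -> fp16 conversion;
    # intermediate stages already receive half-precision activations from the previous
    # stage. Likewise only the last stage, whose outputs feed fp32 consumers such as
    # the loss computation, converts its outputs back to fp32.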
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)
...@@ -214,7 +227,7 @@ class FP16_Optimizer(object):
                    master_param = param.detach().clone().float()
                    master_param.requires_grad = True
                    # Copy the model parallel flag.
                    master_param.intra_layer_model_parallel = param.intra_layer_model_parallel
                    param_group['params'][i] = master_param
                    fp32_from_fp16_params_this_group.append(master_param)
                    # Reset existing state dict key to the new master param.
...@@ -512,7 +525,8 @@ class FP16_Optimizer(object):
        return retval

    def backward(self, output_tensor, update_master_grads=True, retain_graph=False,
                 output_tensor_grad=None):
        """
        :attr:`backward` performs the following conceptual steps:
...@@ -570,7 +584,9 @@ class FP16_Optimizer(object):
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        # Convert output_tensor to float if it's the loss, otherwise stay in half precision.
        self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
                                  output_tensor_grad=output_tensor_grad)

        if update_master_grads:
            self.update_master_grads()
...
...@@ -68,9 +68,13 @@ class LossScaler:
                                       self.loss_scale)
        return grad_in

    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
        if output_tensor_grad is None:
            scaled_output_tensor = output_tensor * self.loss_scale
        else:
            scaled_output_tensor = output_tensor
        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)


class DynamicLossScaler:
...@@ -196,9 +200,13 @@ class DynamicLossScaler:
                                       self.loss_scale)
        return grad_in

    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
        if output_tensor_grad is None:
            scaled_output_tensor = output_tensor * self.loss_scale
        else:
            scaled_output_tensor = output_tensor
        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)
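# Sketch of how a pipeline stage might drive backward() (illustrative only; the actual
# schedule lives in the training loop, which is not part of this hunk). On the last
# stage output_tensor is the scalar loss and gets multiplied by the loss scale; on
# earlier stages it is an activation tensor and output_tensor_grad is the gradient
# received from the next stage, already scaled, so it is passed through unchanged:
#
#     if mpu.is_inter_layer_last_stage():
#         optimizer.backward(loss, update_master_grads=False)
#     else:
#         grad_from_next_stage = recv_backward()  # hypothetical receive helper
#         optimizer.backward(output_tensor, update_master_grads=False,
#                            output_tensor_grad=grad_from_next_stage)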
##############################################################
...
...@@ -26,7 +26,7 @@ from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import set_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_world_size


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
...@@ -65,9 +65,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
            args.use_cpu_initialization=True
            # delayed initialization of DDP-related stuff
            # We only set basic DDP globals
            set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
            # and return function for external DDP manager to call when it has DDP initialized
            set_intra_layer_model_parallel_rank(args.rank)
            return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
...@@ -121,12 +121,14 @@ def _initialize_distributed():
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)

    # Set the intra-layer model-parallel, inter-layer model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.intra_layer_model_parallel_size,
                                          args.inter_layer_model_parallel_size)


def _init_autoresume():
...@@ -138,14 +140,16 @@ def _init_autoresume():
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different inter-layer MP stages get different seeds.
        seed = seed_ + mpu.get_inter_layer_model_parallel_rank()
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.intra_layer_model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))
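    # Example: with --seed 1234 and four inter-layer stages, the stages seed their RNGs
    # with 1234, 1235, 1236 and 1237 respectively (values shown for illustration only).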
...
...@@ -14,8 +14,8 @@
# limitations under the License.

from .distributed import *
from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model
...@@ -19,6 +19,7 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.model.language_model import Embedding
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
...@@ -77,7 +78,7 @@ class BertLMHead(MegatronModule):
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.intra_layer_model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output
...@@ -101,17 +102,43 @@ class BertLMHead(MegatronModule):
        return output


def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(
        lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits


class BertModelBase(MegatronModule):
    """Bert Language model."""

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelBase, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
...@@ -123,52 +150,67 @@ class BertModel(MegatronModule):
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if mpu.is_inter_layer_last_stage():
            if not mpu.is_inter_layer_first_stage():
                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                # If first and last stages are different, set word_embeddings
                # weights to 0 here, then copy first stage's weights using all_reduce
                # below.
                self.word_embeddings = mpu.VocabParallelEmbedding(
                    args.padded_vocab_size, args.hidden_size,
                    init_method=init_method_normal(args.init_method_std))
                self.word_embeddings.weight.data.fill_(0)
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

        # Ensure that first and last stages have the same initial embedding weights.
        if mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())
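        # Why the zero-fill plus all_reduce ties the embeddings: only the first stage
        # holds the initialized embedding values at this point and the last stage
        # contributes zeros, so the sum over the embedding group (presumably just the
        # first and last stages) leaves both with identical copies of the first
        # stage's weights.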
    def word_embeddings_weight(self):
        if mpu.is_inter_layer_first_stage():
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_inter_layer_last_stage():
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')
    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_inter_layer_first_stage():
            input_ids = bert_model_input
            position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [bert_model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_inter_layer_last_stage() and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if mpu.is_inter_layer_last_stage():
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
...@@ -180,12 +222,17 @@ class BertModel(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_inter_layer_last_stage():
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if mpu.is_inter_layer_last_stage() and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
    def load_state_dict(self, state_dict, strict=True):
...@@ -193,8 +240,74 @@ class BertModel(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_inter_layer_last_stage():
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if mpu.is_inter_layer_last_stage() and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
class BertModel(BertModelBase):

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModel, self).__init__(
            num_tokentypes=num_tokentypes,
            add_binary_head=add_binary_head,
            parallel_output=parallel_output)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None, lm_labels=None):
        return super(BertModel, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids,
            lm_labels=lm_labels)


class BertModelFirstStage(BertModelBase):

    def __init__(self, num_tokentypes=2):
        super(BertModelFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, attention_mask,
                tokentype_ids=None):
        return super(BertModelFirstStage, self).forward(
            input_ids,
            attention_mask,
            tokentype_ids=tokentype_ids)


class BertModelIntermediateStage(BertModelBase):

    def __init__(self, num_tokentypes=2):
        super(BertModelIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask):
        return super(BertModelIntermediateStage, self).forward(
            hidden_state,
            attention_mask)


class BertModelLastStage(BertModelBase):

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelLastStage, self).__init__(
            num_tokentypes=num_tokentypes,
            add_binary_head=add_binary_head,
            parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask,
                lm_labels=None):
        return super(BertModelLastStage, self).forward(
            hidden_state,
            attention_mask,
            lm_labels=lm_labels)
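How a pretraining script chooses between these classes is not part of this hunk; a minimal sketch of the intended selection, assuming the `mpu` stage queries introduced elsewhere in this commit and an already-initialized model-parallel state, might look like:

<pre>
from megatron import mpu
from megatron.model import (BertModel, BertModelFirstStage,
                            BertModelIntermediateStage, BertModelLastStage)


def build_bert_model(num_tokentypes=2):
    # Single-stage runs keep the original class; pipelined runs pick a stage-specific
    # wrapper so each rank only builds the pieces it actually needs.
    if mpu.get_inter_layer_model_parallel_world_size() == 1:
        return BertModel(num_tokentypes=num_tokentypes)
    if mpu.is_inter_layer_first_stage():
        return BertModelFirstStage(num_tokentypes=num_tokentypes)
    if mpu.is_inter_layer_last_stage():
        return BertModelLastStage(num_tokentypes=num_tokentypes)
    return BertModelIntermediateStage(num_tokentypes=num_tokentypes)
</pre>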
...@@ -56,8 +56,7 @@ class Classification(MegatronModule):
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids, position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)
...
...@@ -21,6 +21,7 @@ from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule

from .language_model import Embedding
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
...@@ -32,11 +33,40 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):
    return attention_scores


def post_language_model_processing(lm_output, labels, logit_weights,
                                   get_key_value, parallel_output,
                                   forward_method_parallel_output,
                                   fp16_lm_cross_entropy):
    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss


class GPT2ModelBase(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2ModelBase, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
...@@ -50,43 +80,53 @@ class GPT2Model(MegatronModule):
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

        if mpu.is_inter_layer_last_stage():
            if not mpu.is_inter_layer_first_stage():
                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                # If first and last stages are different, set word_embeddings
                # weights to 0 here, then copy first stage's weights using all_reduce
                # below.
                self.word_embeddings = mpu.VocabParallelEmbedding(
                    args.padded_vocab_size, args.hidden_size,
                    init_method=init_method_normal(args.init_method_std))
                self.word_embeddings.weight.data.fill_(0)

        # Ensure that first and last stages have the same initial embedding weights.
        if mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())

    def word_embeddings_weight(self):
        if mpu.is_inter_layer_first_stage():
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_inter_layer_last_stage():
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')
    def forward(self, gpt2_model_input, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
        if mpu.is_inter_layer_first_stage():
            (input_ids, position_ids) = gpt2_model_input
            args = [input_ids, position_ids, attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [gpt2_model_input, attention_mask]
        lm_output = self.language_model(*args, **kwargs)

        if mpu.is_inter_layer_last_stage():
            return post_language_model_processing(
                lm_output, labels,
                self.word_embeddings_weight(),
                get_key_value,
                self.parallel_output,
                forward_method_parallel_output,
                self.fp16_lm_cross_entropy)
        else:
            return lm_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
...@@ -95,11 +135,89 @@ class GPT2Model(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Load word_embeddings.
        if mpu.is_inter_layer_last_stage() and not mpu.is_inter_layer_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)

        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
class GPT2Model(GPT2ModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2Model, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPT2Model, self).forward(
            (input_ids, position_ids),
            attention_mask,
            labels=labels,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)


class GPT2ModelFirstStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPT2ModelFirstStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        return super(GPT2ModelFirstStage, self).forward(
            (input_ids, position_ids),
            attention_mask,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPT2ModelIntermediateStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0):
        super(GPT2ModelIntermediateStage, self).__init__(
            num_tokentypes=num_tokentypes)

    def forward(self, hidden_state, attention_mask,
                layer_past=None, get_key_value=False):
        return super(GPT2ModelIntermediateStage, self).forward(
            hidden_state,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value)


class GPT2ModelLastStage(GPT2ModelBase):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2ModelLastStage, self).__init__(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

    def forward(self, hidden_state, attention_mask, labels=None,
                layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):
        return super(GPT2ModelLastStage, self).forward(
            hidden_state,
            attention_mask,
            labels=labels,
            layer_past=layer_past,
            get_key_value=get_key_value,
            forward_method_parallel_output=forward_method_parallel_output)