Commit 9b174da8 authored by Deepak Narayanan's avatar Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, a SQuAD 2.0 F1-score of 91.7, and a RACE accuracy of 90.9%.
...@@ -218,7 +218,12 @@ These scripts use the PyTorch distributed launcher for distributed training. As
The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the back-propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back-propagation computation. To switch between these two options, use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model-parallel sizes. For example, for the 8.3 billion parameter model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory, and for some configurations (e.g., 2.5 billion parameters using 2-way model parallelism and 1.2 billion parameters with no model parallelism) it can make the overall training slower. We empirically found that using a smaller model in those cases improves the training time.
Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages gives each stage 6 transformer layers). The number of microbatches in a per-pipeline minibatch is controlled by the `--num-microbatches-in-minibatch` argument. With `WORLD_SIZE` GPUs, `TENSOR_MP_SIZE` tensor-model-parallel size, and `PIPELINE_MP_SIZE` pipeline-model-parallel size, `WORLD_SIZE`/(`TENSOR_MP_SIZE` * `PIPELINE_MP_SIZE`) GPUs will be used for data parallelism. The default value for both `--tensor-model-parallel-size` and `--pipeline-model-parallel-size` is 1, which disables that form of model parallelism.
We have examples of how to use these two different forms of model parallelism in these scripts:
`bash examples/pretrain_bert_distributed_with_mp.sh`
`bash examples/pretrain_gpt2_distributed_with_mp.sh`
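To make the sizing arithmetic described above concrete, here is a small illustrative sketch (not part of the repository; the function name and example values are hypothetical) of how the data-parallel size follows from the world size and the two model-parallel degrees:
<pre>
# Illustrative only: mirrors the WORLD_SIZE / (TENSOR_MP_SIZE * PIPELINE_MP_SIZE) relation above.
def data_parallel_size(world_size, tensor_mp_size, pipeline_mp_size):
    assert world_size % (tensor_mp_size * pipeline_mp_size) == 0
    return world_size // (tensor_mp_size * pipeline_mp_size)

# e.g., 64 GPUs with --tensor-model-parallel-size 4 and --pipeline-model-parallel-size 4
# leave 64 / (4 * 4) = 4-way data parallelism.
print(data_parallel_size(64, 4, 4))  # prints 4
</pre>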
Other than these minor changes, the distributed training is identical to the training on a single GPU.
...@@ -245,7 +250,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
...@@ -269,7 +274,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>
...@@ -362,14 +367,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
...@@ -488,7 +493,7 @@ Further command line arguments are described in the source file [`main.py`](./ta
## BERT Task Evaluation
<a id="race-evaluation"></a>
### RACE Evaluation
The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files. Note that for RACE, the batch size is the number of RACE queries to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line.
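As a quick worked example of that note (hypothetical batch size):
<pre>
# Each RACE query expands into four samples, so the model sees 4x the
# command-line batch size per step (illustrative numbers only).
batch_size = 8                    # value passed on the command line
samples_per_query = 4
effective_batch_size = batch_size * samples_per_query   # 32
</pre>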
<pre>
TRAIN_DATA="data/RACE/train/middle"
......
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--load $CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 5 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
...@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 3 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
...@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
python tools/generate_samples_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
......
#!/bin/bash
TENSOR_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
......
...@@ -32,4 +32,3 @@ python pretrain_bert.py \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
...@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......
#!/bin/bash
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DATA_PATH=<Specify path and file prefix>_text_sentence
CHECKPOINT_PATH=<Specify path>
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 2 \
--num-microbatches-in-minibatch 2 \
--seq-length 512 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file bert-vocab.txt \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
...@@ -38,6 +38,3 @@ python pretrain_gpt2.py \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
set +x
...@@ -17,7 +17,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
...@@ -46,7 +46,3 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
set +x
#! /bin/bash
# Runs the "345M" parameter model
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DATA_PATH=<Specify path and file prefix>_text_document
CHECKPOINT_PATH=<Specify path>
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--num-microbatches-in-minibatch 2 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file gpt2-vocab.json \
--merge-file gpt2-merges.txt \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--checkpoint-activations \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
...@@ -26,6 +26,9 @@ from .package_info import (
)
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
...@@ -33,9 +36,21 @@ from .global_vars import get_timers
from .initialize import initialize_megatron
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
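A minimal usage sketch of the new helper (hypothetical call site; with pipeline parallelism the loss is computed on the last stage, which appears to be why logging moves from rank 0 to the last rank):
<pre>
# Hypothetical call site: report training progress from the last pipeline stage.
from megatron import print_rank_last

loss = 2.345  # placeholder value
print_rank_last(' iteration 100 | lm loss: {:.4f}'.format(loss))
</pre>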
...@@ -54,10 +54,56 @@ def parse_args(extra_args_provider=None, defaults={},
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
' ({}) is not divisible by tensor model parallel size ({})'.format(
args.world_size, args.tensor_model_parallel_size)
# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
if args.pipeline_model_parallel_size > 1:
if "ring_exchange" not in dir(torch.distributed):
raise Exception('PyTorch with torch.distributed.ring_exchange '
'needed to run pipeline MP!')
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\
' divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead'
del args.batch_size
assert args.warmup is None, '--warmup argument is no longer valid, use ' \
'--lr-warmup-fraction instead'
del args.warmup
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
# Fp16 loss scaling.
args.dynamic_loss_scale = False
...@@ -90,10 +136,40 @@ def parse_args(extra_args_provider=None, defaults={},
else:
setattr(args, key, defaults[key])
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
# sample-based options are off.
assert args.train_samples is None, \
'expected iteration-based training'
assert args.lr_decay_samples is None, \
'expected iteration-based learning rate decay'
assert args.lr_warmup_samples == 0, \
'expected iteration-based learning rate warmup'
assert args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_iters == 0, \
'can only specify one of lr-warmup-fraction and lr-warmup-iters'
# Sample-based training.
if args.train_samples:
# If we use sample-based training, make sure the
# iteration-based options are off.
assert args.train_iters is None, \
'expected sample-based training'
assert args.lr_decay_iters is None, \
'expected sample-based learning rate decay'
assert args.lr_warmup_iters == 0, \
'expected sample-based learning rate warmup'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_samples == 0, \
'can only specify one of lr-warmup-fraction and lr-warmup-samples'
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
# Checks.
...@@ -104,14 +180,6 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
# Parameters sharing does not work with torch DDP.
if (args.num_unique_layers is not None) and (args.num_layers is not None):
assert args.num_unique_layers <= args.num_layers
assert args.num_layers % args.num_unique_layers == 0, \
'num-layers should be divisible by num-unique-layers.'
if args.num_unique_layers < args.num_layers:
assert args.DDP_impl == 'local', \
'torch-DDP does not work with parameters sharing.'
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
...@@ -157,16 +225,6 @@ def _add_network_size_args(parser):
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
group.add_argument('--num-unique-layers', type=int, default=None,
help='Number of unique transformer layers. '
'`num-layers` should be divisible by this value.')
group.add_argument('--param-sharing-style', default='grouped',
choices=['grouped', 'spaced'],
help='Ordering of the shared parameters. For example, '
'for a `num-layers`=4 and `--num-unique-layers`=2, '
'we will have the following ordering for two unique '
'layers 1 and 2: '
' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--num-attention-heads', type=int, default=None,
...@@ -197,7 +255,7 @@ def _add_regularization_args(parser):
group = parser.add_argument_group(title='regularization')
group.add_argument('--attention-dropout', type=float, default=0.1,
help='Post attention dropout probability.')
group.add_argument('--hidden-dropout', type=float, default=0.1,
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
...@@ -220,10 +278,32 @@ def _add_regularization_args(parser):
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size increment> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \ '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
...@@ -235,12 +315,19 @@ def _add_training_args(parser):
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
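The ramp-up described in the `--rampup-batch-size` help text above can be sketched as follows; this is a simplified, hypothetical reimplementation for illustration, not the code that actually consumes the argument:
<pre>
def rampup_global_batch_size(consumed_samples, start, increment,
                             rampup_samples, final_global_batch_size):
    # Simplified sketch (assumption): linear ramp-up in fixed-size increments.
    num_increments = (final_global_batch_size - start) // increment
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return min(start + steps * increment, final_global_batch_size)

# Example from the help text: --rampup-batch-size 16 8 300000 --global-batch-size 1024
print(rampup_global_batch_size(0, 16, 8, 300000, 1024))       # 16
print(rampup_global_batch_size(150000, 16, 8, 300000, 1024))  # 520
print(rampup_global_batch_size(300000, 16, 8, 300000, 1024))  # 1024
</pre>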
...@@ -285,12 +372,24 @@ def _add_learning_rate_args(parser):
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
' If None defaults to `--train-iters`')
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over,'
' If None defaults to `--train-samples`')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
help='fraction of lr-warmup-(iters/samples) to use '
'for warmup (as a float)')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-samples', type=int, default=0,
help='number of samples to linearly warmup '
'learning rate over.')
group.add_argument('--warmup', type=int, default=None,
help='Old lr warmup argument, do not use. Use one of the '
'--lr-warmup-* arguments above')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--warmup', type=float, default=0.01,
help='Percentage of total iterations to warmup on '
'(.01 = 1 percent of all training iters).')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
...@@ -365,8 +464,13 @@ def _add_mixed_precision_args(parser):
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
...@@ -495,4 +599,3 @@ def _add_realm_args(parser):
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer report progress')
return parser
...@@ -23,7 +23,7 @@ import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron import mpu, get_args, update_num_microbatches
from megatron import get_args
from megatron import print_rank_0
...@@ -41,11 +41,14 @@ def get_checkpoint_version():
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retrieved from checkpoint."""
args = get_args()
def _compare(arg_name, old_arg_name=None):
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name)
else:
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
...@@ -59,7 +62,12 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
if get_checkpoint_version() >= 3.0:
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
...@@ -70,16 +78,22 @@ def ensure_directory_exists(filename):
def get_checkpoint_name(checkpoints_path, iteration,
release=False):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_tensor_model_parallel_rank()),
'model_optim_rng.pt')
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()),
'model_optim_rng.pt')
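For reference, a small standalone sketch of the path produced by the new naming scheme (iteration, ranks, and the checkpoint root are hypothetical values):
<pre>
import os

# Mirrors the format strings in get_checkpoint_name above.
iteration, tensor_rank, pipeline_rank = 1000, 1, 2
path = os.path.join('checkpoints/bert_345m',
                    'iter_{:07d}'.format(iteration),
                    'mp_rank_{:02d}_{:03d}'.format(tensor_rank, pipeline_rank),
                    'model_optim_rng.pt')
print(path)  # checkpoints/bert_345m/iter_0001000/mp_rank_01_002/model_optim_rng.pt
</pre>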
...@@ -96,12 +110,17 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if torch.distributed.get_rank() == 0:
print('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True)
if mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()
...@@ -123,14 +142,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration, checkpoint_name))
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True)
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
...@@ -178,9 +197,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0:
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Load the checkpoint.
try:
...@@ -222,6 +241,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else:
...@@ -261,8 +281,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
sys.exit()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully loaded checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
return iteration
......
...@@ -153,8 +153,10 @@ def get_samples_mapping_(indexed_dataset,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
......
...@@ -29,16 +29,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
return None
args = get_args()
world_size = mpu.get_data_parallel_world_size()
global_batch_size = args.batch_size * world_size
# Megatron sampler
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
...@@ -50,13 +47,15 @@ def build_pretraining_data_loader(dataset, consumed_samples):
class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
...@@ -64,19 +63,11 @@ class MegatronPretrainingSampler:
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
...@@ -88,8 +79,8 @@ class MegatronPretrainingSampler:
# Last batch if not complete will be dropped.
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
yield batch[start_idx:end_idx]
batch = []
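To see how the rewritten sampler slices a global batch across data-parallel ranks, here is a standalone sketch with illustrative values (not repository code):
<pre>
# Standalone illustration of the slicing in __iter__ above.
micro_batch_size = 2
data_parallel_size = 2
total_samples = 8

batch = []
for idx in range(total_samples):
    batch.append(idx)
    if len(batch) == micro_batch_size * data_parallel_size:
        for data_parallel_rank in range(data_parallel_size):
            start = data_parallel_rank * micro_batch_size
            print(data_parallel_rank, batch[start:start + micro_batch_size])
        batch = []
# rank 0 yields [0, 1] then [4, 5]; rank 1 yields [2, 3] then [6, 7]
</pre>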
...@@ -418,11 +418,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
......
...@@ -55,14 +55,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
...@@ -210,9 +219,49 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
# For the last epoch, decide whether include the entire epoch
# in the global shuffle or not.
# If we need only one epoch, then separating last epoch does
# not mean anything.
if num_epochs == 1:
separate_last_epoch = False
print(' > only one epoch required, setting '
'separate_last_epoch to False', flush=True)
else:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one = (
(num_epochs - 1) * tokens_per_epoch - 1) // seq_length
last_epoch_num_samples = num_samples - \
num_samples_from_epochs_minus_one
assert last_epoch_num_samples >= 0, \
'last epoch number of samples should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
'last epoch number of samples exceeded max value.'
# If we have less than 80% of the samples for the last epoch,
# separate out the epoch and treat it differently.
# Note: the 80% number is just based on common sense and can
# be adjusted if needed.
separate_last_epoch = (last_epoch_num_samples <
int(0.80 * num_samples_per_epoch))
if separate_last_epoch:
string = ' > last epoch number of samples ({}) is smaller '\
'than 80% of number of samples per epoch ({}), '\
'setting separate_last_epoch to True'
else:
string = ' > last epoch number of samples ({}) is larger '\
'than 80% of number of samples per epoch ({}), '\
'setting separate_last_epoch to False'
print(string.format(last_epoch_num_samples,
num_samples_per_epoch), flush=True)
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
...@@ -236,7 +285,12 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time()
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_,
sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
...@@ -246,8 +300,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load mappings.
start_time = time.time()
...@@ -289,15 +345,20 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples):
return num_epochs
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-documents.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_sample_idx(sizes, doc_idx, seq_length,
...@@ -349,11 +410,23 @@ def _build_sample_idx(sizes, doc_idx, seq_length,
return sample_idx
def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, total_size) and shuffle."""
print(' > building shuffle index with split [0, {}) and [{}, {}) '
'...'.format(num_samples, num_samples, total_size), flush=True)
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx_first = np.arange(start=0, stop=num_samples,
step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
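A quick standalone check (illustrative sizes and seed) of the two-part shuffle above, which permutes the full epochs and the partial last epoch independently so the tail samples stay at the end of the index:
<pre>
import numpy as np

# Illustrative values: 6 samples from full epochs, 4 from the partial last epoch.
np_rng = np.random.RandomState(seed=1234)
num_samples, total_size = 6, 10

shuffle_idx_first = np.arange(0, num_samples, dtype=np.uint32)
np_rng.shuffle(shuffle_idx_first)
shuffle_idx_last = np.arange(num_samples, total_size, dtype=np.uint32)
np_rng.shuffle(shuffle_idx_last)

shuffle_idx = np.concatenate((shuffle_idx_first, shuffle_idx_last))
print(shuffle_idx)  # indices 0-5 permuted first, then 6-9 permuted among themselves
</pre>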
...@@ -9,15 +9,15 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
from megatron import get_args, get_tokenizer, print_rank_0, mpu
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
global_batch_size = micro_batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
......