Commit 23266c57 authored by Jared Casper

Merge branch 'slym/partial_chpt' into 'main'

Checkpoint a set number of individual Transformer layers

See merge request ADLR/megatron-lm!301
parents 3715b0bc 99f47676
@@ -156,7 +156,7 @@ OUTPUT_ARGS="--log-interval 10 \
  --save-interval 500 \
  --eval-interval 100 \
  --eval-iters 10 \
-  --checkpoint-activations"
+  --activations-checkpoint-method uniform"
python pretrain_bert.py \
  $BERT_ARGS \
@@ -302,6 +302,15 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_<model>.py \
The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)) can be enabled using the `--num-layers-per-virtual-pipeline-stage` argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with `NUM_LAYERS / PIPELINE_MP_SIZE` transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as `GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)`) should be divisible by the `PIPELINE_MP_SIZE` when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (`PIPELINE_MP_SIZE=2`).
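As a worked example of these divisibility constraints (an editorial sketch, not part of this merge request; the concrete sizes below are arbitrary example values written in the README's notation):

```python
# Example check of the interleaved-schedule constraints described above.
GLOBAL_BATCH_SIZE = 1536
MICRO_BATCH_SIZE = 2
DATA_PARALLEL_SIZE = 8
PIPELINE_MP_SIZE = 8
NUM_LAYERS = 96
LAYERS_PER_VIRTUAL_STAGE = 4   # --num-layers-per-virtual-pipeline-stage

num_microbatches = GLOBAL_BATCH_SIZE // (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)  # 96

assert NUM_LAYERS % LAYERS_PER_VIRTUAL_STAGE == 0   # total layers divisible by the argument value
assert num_microbatches % PIPELINE_MP_SIZE == 0     # microbatch count divisible by PIPELINE_MP_SIZE
assert PIPELINE_MP_SIZE != 2                        # interleaving unsupported with exactly 2 stages
```

Here `1536 / (8 * 2) = 96` microbatches, which is divisible by the 8 pipeline stages, so the interleaved schedule is allowed for this configuration.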
## Activation Checkpointing and Recomputation

To reduce GPU memory usage so that a large model fits on a given training system, we support activation checkpointing and recomputation. We use a Transformer layer as the unit of checkpointing because the activation size peaks in the middle of a Transformer layer, so checkpointing only the input of each Transformer layer is storage-efficient. We support two activation checkpointing methods: `uniform` and `block`.

The `uniform` method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in memory. The baseline group size is 1, in which case the input activation of every Transformer layer is checkpointed. When GPU memory is insufficient, increasing the number of layers per group reduces memory usage and thus enables running a bigger model. For example, with 4 layers per group, only the input activation of each group of 4 Transformer layers is checkpointed.

The `block` method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and runs the remaining layers without any checkpointing. It is intended for cases where some GPU memory would otherwise go unused: skipping checkpointing for some Transformer layers fills that memory with stored activations, and checkpointing fewer Transformer layers avoids unnecessary activation recomputation during the backward pass, which improves training performance. For example, if we specify that 5 of the 8 layers in a pipeline stage are checkpointed, only the input activations of the first 5 Transformer layers are checkpointed, and no activation recomputation is needed for the remaining 3 layers during the backward pass.
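As a rough sketch of how the two methods select layers (an editorial illustration, not part of this merge request: the helper function and the concrete numbers are made up here, but the selection logic mirrors `ParallelTransformer._checkpointed_forward` in the transformer.py changes further down):

```python
def checkpointed_layer_starts(num_layers, method, num_layers_per_chunk):
    """Return the layer indices whose input activations are checkpointed.

    'uniform' checkpoints the input of every chunk of `num_layers_per_chunk`
    layers; 'block' checkpoints the first `num_layers_per_chunk` individual
    layers of the pipeline stage and skips the rest.
    """
    if method == 'uniform':
        return list(range(0, num_layers, num_layers_per_chunk))
    if method == 'block':
        return list(range(min(num_layers_per_chunk, num_layers)))
    raise ValueError("Invalid activation checkpoint method.")

# 8 Transformer layers per pipeline stage:
print(checkpointed_layer_starts(8, 'uniform', 4))  # [0, 4] -> two chunks of 4 layers each
print(checkpointed_layer_starts(8, 'block', 5))    # [0, 1, 2, 3, 4] -> last 3 layers never recomputed
```

With 8 layers per stage, `uniform` with 4 layers per group keeps two checkpoints (at layers 0 and 4) and recomputes everything else in the backward pass, while `block` with 5 layers keeps the full activations of the last 3 layers and never recomputes them.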
## GPT-3 Example

In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with the [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adapted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With the options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, training will start with a global batch size of 16 and linearly increase the global batch size to 1536 over 5,859,375 samples in increments of 16. The training dataset can be either a single set or multiple datasets combined with a set of weights.
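To make the ramp-up arithmetic concrete, here is a small editorial sketch (not part of this merge request) that assumes the batch-size increments are spread evenly over the ramp-up samples, as the sentence above describes; the helper function name is hypothetical:

```python
start, increment, rampup_samples = 16, 16, 5_859_375
target = 1536

num_increments = (target - start) // increment            # 95 steps of +16
samples_per_increment = rampup_samples // num_increments  # ~61,677 samples per step

def global_batch_size(consumed_samples):
    """Approximate global batch size after `consumed_samples` training samples."""
    steps = min(consumed_samples // samples_per_increment, num_increments)
    return start + steps * increment

print(global_batch_size(0))          # 16
print(global_batch_size(3_000_000))  # partway through the ramp
print(global_batch_size(6_000_000))  # 1536 once the ramp is complete
```

The ramp thus takes (1536 - 16) / 16 = 95 increments of roughly 61,700 samples each before the global batch size settles at 1536.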
@@ -345,7 +354,7 @@ python pretrain_ict.py \
  --max-position-embeddings 256 \
  --ict-head-size 128 \
  --train-iters 100000 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --bert-load /path/to/pretrained_bert \
  --load checkpoints \
  --save checkpoints \
@@ -375,7 +384,7 @@ python tools/create_doc_index.py \
  --ict-head-size 128 \
  --num-attention-heads 12 \
  --batch-size 128 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --seq-length 256 \
  --max-position-embeddings 256 \
  --ict-load /path/to/pretrained_ict \
@@ -482,7 +491,7 @@ python tasks/main.py \
  --merge-file $MERGE_FILE \
  --load $CHECKPOINT_PATH \
  --micro-batch-size 8 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --log-interval 10 \
  --no-load-optim \
  --no-load-rng
@@ -512,7 +521,7 @@ python tasks/main.py \
  --merge-file $MERGE_FILE \
  --load $CHECKPOINT_PATH \
  --micro-batch-size 8 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --log-interval 10 \
  --no-load-optim \
  --no-load-rng
@@ -542,7 +551,7 @@ COMMON_TASK_ARGS="--num-layers 24 \
COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
  --valid-data $VALID_DATA \
  --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --save-interval 10000 \
  --save $CHECKPOINT_PATH \
  --log-interval 100 \
...
@@ -20,7 +20,7 @@ python tasks/main.py \
  --num-attention-heads 12 \
  --tensor-model-parallel-size 1 \
  --micro-batch-size 128 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --seq-length 512 \
  --max-position-embeddings 512 \
  --load ${CHECKPOINT_PATH} \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
  --hidden-size 1024 \
  --num-attention-heads 16 \
  --batch-size 8 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --seq-length 1024 \
  --max-position-embeddings 1024 \
  --log-interval 10 \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
  --hidden-size 1024 \
  --num-attention-heads 16 \
  --micro-batch-size 8 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --lr 5.0e-5 \
  --lr-decay-style linear \
  --lr-warmup-fraction 0.065 \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
  --hidden-size 1024 \
  --num-attention-heads 16 \
  --micro-batch-size 4 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --lr 1.0e-5 \
  --lr-decay-style linear \
  --lr-warmup-fraction 0.06 \
...
@@ -33,7 +33,7 @@ python pretrain_gpt.py \
  --weight-decay 1e-2 \
  --clip-grad 1.0 \
  --lr-warmup-fraction .01 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --log-interval 100 \
  --save-interval 10000 \
  --eval-interval 1000 \
...
@@ -49,7 +49,7 @@ options=" \
  --init-method-std 0.006 \
  --tensorboard-dir <TENSORBOARD DIRECTORY> \
  --fp16 \
-  --checkpoint-activations "
+  --activations-checkpoint-method uniform "
run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"
...
@@ -40,7 +40,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
  --weight-decay 1e-2 \
  --clip-grad 1.0 \
  --lr-warmup-fraction .01 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --log-interval 100 \
  --save-interval 10000 \
  --eval-interval 1000 \
...
@@ -42,7 +42,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
  --weight-decay 1e-2 \
  --clip-grad 1.0 \
  --lr-warmup-fraction .01 \
-  --checkpoint-activations \
+  --activations-checkpoint-method uniform \
  --log-interval 100 \
  --save-interval 10000 \
  --eval-interval 1000 \
...
@@ -25,7 +25,7 @@ MBS=1
HS=20480
NAH=128
DDP=local
-MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
# Name of the job.
...
@@ -16,9 +16,9 @@ GBS=12
# Set interleaved schedule options.
if [ ${INTERLEAVED} == "YES" ]; then
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 2 "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 "
elif [ ${INTERLEAVED} == "NO" ]; then
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
else
  echo "Invalid configuration"
  exit 1
...
@@ -24,7 +24,7 @@ NLS=32
HS=20480
NAH=128
DDP=local
-MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8
...
@@ -25,7 +25,7 @@ NLS=32
HS=3840
NAH=32
DDP=local
-MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8
...
@@ -25,7 +25,7 @@ NLS=32
HS=3840
NAH=32
DDP=local
-MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8
...
@@ -21,7 +21,7 @@ NLS=32
HS=15360
NAH=128
DDP=local
-MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8
...
@@ -16,7 +16,7 @@ GBS=1
# Set activation recomputation.
if [ ${ACTIVATION_RECOMPUTATION} == "YES" ]; then
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${ACTIVATION_RECOMPUTATION} == "NO" ]; then
  MEGATRON_EXTRA_PARAMS=""
else
...
@@ -16,9 +16,9 @@ GBS=12
# Set scatter-gather communication optimization options.
if [ ${SCATTER_GATHER} == "YES" ]; then
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 2 "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 "
elif [ ${SCATTER_GATHER} == "NO" ]; then
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 2 --no-scatter-gather-tensors-in-pipeline "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 --no-scatter-gather-tensors-in-pipeline "
else
  echo "Invalid configuration"
  exit 1
...
@@ -21,7 +21,7 @@ if [ ${MODEL_SIZE} == "1.7B" ]; then
  NAH=24
  DDP=torch
  NNODES=4
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "3.6B" ]; then
  TP=2
  PP=1
@@ -32,7 +32,7 @@ elif [ ${MODEL_SIZE} == "3.6B" ]; then
  NAH=32
  DDP=torch
  NNODES=8
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "7.5B" ]; then
  TP=4
  PP=1
@@ -43,7 +43,7 @@ elif [ ${MODEL_SIZE} == "7.5B" ]; then
  NAH=32
  DDP=torch
  NNODES=16
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "18B" ]; then
  TP=8
  PP=1
@@ -54,7 +54,7 @@ elif [ ${MODEL_SIZE} == "18B" ]; then
  NAH=48
  DDP=torch
  NNODES=32
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "39B" ]; then
  TP=8
  PP=2
@@ -65,7 +65,7 @@ elif [ ${MODEL_SIZE} == "39B" ]; then
  NAH=64
  DDP=local
  NNODES=64
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "76B" ]; then
  TP=8
  PP=4
@@ -76,7 +76,7 @@ elif [ ${MODEL_SIZE} == "76B" ]; then
  NAH=80
  DDP=local
  NNODES=128
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 5"
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5"
elif [ ${MODEL_SIZE} == "145B" ]; then
  TP=8
  PP=8
@@ -87,7 +87,7 @@ elif [ ${MODEL_SIZE} == "145B" ]; then
  NAH=96
  DDP=local
  NNODES=192
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 5 "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5 "
elif [ ${MODEL_SIZE} == "310B" ]; then
  TP=8
  PP=16
@@ -98,7 +98,7 @@ elif [ ${MODEL_SIZE} == "310B" ]; then
  NAH=128
  DDP=local
  NNODES=240
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 3 "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 3 "
elif [ ${MODEL_SIZE} == "530B" ]; then
  TP=8
  PP=35
@@ -109,7 +109,7 @@ elif [ ${MODEL_SIZE} == "530B" ]; then
  NAH=128
  DDP=local
  NNODES=315
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations --num-layers-per-virtual-pipeline-stage 1 "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 1 "
elif [ ${MODEL_SIZE} == "1T" ]; then
  TP=8
  PP=64
@@ -120,7 +120,7 @@ elif [ ${MODEL_SIZE} == "1T" ]; then
  NAH=160
  DDP=local
  NNODES=384
-  MEGATRON_EXTRA_PARAMS="--checkpoint-activations "
+  MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
else
  echo "Invalid configuration"
  exit 1
...
@@ -91,6 +91,13 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        args.activations_checkpoint_method = 'uniform'
+        if args.rank == 0:
+            print('--checkpoint-activations is no longer valid, '
+                  'use --activations-checkpoint-method instead. '
+                  'Defaulting to activations-checkpoint-method=uniform.')
+    del args.checkpoint_activations

     # Set input defaults.
     for key in defaults:
@@ -233,9 +240,9 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to use a valid activations-checkpoint-method (\'uniform\' or \'block\')'
     _print_args(args)
     return args
@@ -401,8 +408,20 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) block: checkpoint the input activations of only a set '
+                       'number of individual Transformer layers per pipeline stage '
+                       'and do the rest without any checkpointing, '
+                       'default) do not apply activation checkpointing to any layers')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
...
@@ -542,8 +542,8 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = None

         # Store activation checkpointing flag.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers

         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -609,12 +609,31 @@ class ParallelTransformer(MegatronModule):
         # Make sure memory is freed.
         mpu.reset_checkpointed_activations_memory_buffer()

-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # This method further reduces memory usage by storing fewer checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # This method makes full use of device memory and removes redundant recomputation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")

         return hidden_states
@@ -637,7 +656,7 @@ class ParallelTransformer(MegatronModule):
                 'for not None values in layer_past, ' \
                 'expected get_key_value to be set'
         if get_key_value:
-            assert not self.checkpoint_activations, \
+            assert self.activations_checkpoint_method is None, \
                 'get_key_value does not work with ' \
                 'activation checkpointing'
@@ -656,7 +675,7 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
...