Commit c1e0689d authored by slym

Checkpoint a set number of individual Transformer layers

consider the case of pipeline-model parallelism

clean up arguments

argument naming cleanup

update readme and examples
parent 68797d90
@@ -156,7 +156,7 @@ OUTPUT_ARGS="--log-interval 10 \
              --save-interval 500 \
              --eval-interval 100 \
              --eval-iters 10 \
-             --checkpoint-activations"
+             --activations-checkpoint-method uniform"
 
 python pretrain_bert.py \
        $BERT_ARGS \
@@ -345,7 +345,7 @@ python pretrain_ict.py \
        --max-position-embeddings 256 \
        --ict-head-size 128 \
        --train-iters 100000 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --bert-load /path/to/pretrained_bert \
        --load checkpoints \
        --save checkpoints \
@@ -375,7 +375,7 @@ python tools/create_doc_index.py \
        --ict-head-size 128 \
        --num-attention-heads 12 \
        --batch-size 128 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 256 \
        --max-position-embeddings 256 \
        --ict-load /path/to/pretrained_ict \
@@ -482,7 +482,7 @@ python tasks/main.py \
        --merge-file $MERGE_FILE \
        --load $CHECKPOINT_PATH \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 10 \
        --no-load-optim \
        --no-load-rng
@@ -512,7 +512,7 @@ python tasks/main.py \
        --merge-file $MERGE_FILE \
        --load $CHECKPOINT_PATH \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 10 \
        --no-load-optim \
        --no-load-rng
@@ -542,7 +542,7 @@ COMMON_TASK_ARGS="--num-layers 24 \
 COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
                       --valid-data $VALID_DATA \
                       --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
-                      --checkpoint-activations \
+                      --activations-checkpoint-method uniform \
                       --save-interval 10000 \
                       --save $CHECKPOINT_PATH \
                       --log-interval 100 \
...
@@ -20,7 +20,7 @@ python tasks/main.py \
        --num-attention-heads 12 \
        --tensor-model-parallel-size 1 \
        --micro-batch-size 128 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --load ${CHECKPOINT_PATH} \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --log-interval 10 \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --lr 5.0e-5 \
        --lr-decay-style linear \
        --lr-warmup-fraction 0.065 \
...
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --micro-batch-size 4 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --lr 1.0e-5 \
        --lr-decay-style linear \
        --lr-warmup-fraction 0.06 \
...
@@ -33,7 +33,7 @@ python pretrain_gpt.py \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -49,7 +49,7 @@ options=" \
        --init-method-std 0.006 \
        --tensorboard-dir <TENSORBOARD DIRECTORY> \
        --fp16 \
-       --checkpoint-activations "
+       --activations-checkpoint-method uniform "
 
 run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"
 
...
@@ -40,7 +40,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -42,7 +42,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -91,6 +91,12 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        print('--checkpoint-activations is no longer valid, '
+              'use --activations-checkpoint-method instead. '
+              'Defaulting to activations-checkpoint-method=uniform.')
+        args.activations_checkpoint_method = 'uniform'
+    del args.checkpoint_activations
 
     # Set input defaults.
     for key in defaults:
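For context on the shim above: a user who still passes the old boolean flag is silently mapped onto the uniform method. Below is a minimal standalone sketch of that behavior using a throwaway argparse parser rather than Megatron's real one; the flag registrations here are assumptions for illustration only.

    import argparse

    # Hypothetical, trimmed-down parser: only the flags relevant to this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated
    parser.add_argument('--activations-checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'])
    parser.add_argument('--activations-checkpoint-num-layers', type=int, default=1)

    args = parser.parse_args(['--checkpoint-activations'])  # an "old style" command line

    # Same translation the commit adds to parse_args().
    if args.checkpoint_activations:
        print('--checkpoint-activations is no longer valid, '
              'use --activations-checkpoint-method instead. '
              'Defaulting to activations-checkpoint-method=uniform.')
        args.activations_checkpoint_method = 'uniform'
    del args.checkpoint_activations

    print(args.activations_checkpoint_method)  # -> 'uniform'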
@@ -234,9 +240,9 @@ def parse_args(extra_args_provider=None, defaults={},
             'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to set a valid activations-checkpoint-method (\'uniform\' or \'block\')'
 
     _print_args(args)
     return args
@@ -402,8 +408,19 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) block: checkpoint the input activation of only a set '
+                       'number of individual Transformer layers and skip the rest, '
+                       'default) checkpoint the inputs of every Transformer layer')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
...
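To make the uniform/block help text concrete, here is a plain-Python sketch (not Megatron code) that lists which layer ranges of a hypothetical 8-layer pipeline stage end up wrapped in a checkpoint call under each method:

    def checkpointed_chunks(method, num_layers, num_layers_per_chunk):
        """Return the (start, end) layer ranges whose input activations are checkpointed."""
        if method == 'uniform':
            # Walk the stage in chunks of num_layers_per_chunk; every chunk is checkpointed.
            return [(l, l + num_layers_per_chunk)
                    for l in range(0, num_layers, num_layers_per_chunk)]
        elif method == 'block':
            # Only the first num_layers_per_chunk layers are checkpointed, one by one;
            # the remaining layers run normally and keep their activations.
            return [(l, l + 1) for l in range(min(num_layers_per_chunk, num_layers))]
        raise ValueError("Invalid activation checkpoint method.")

    # Hypothetical 8-layer stage with --activations-checkpoint-num-layers 2:
    print(checkpointed_chunks('uniform', 8, 2))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
    print(checkpointed_chunks('block', 8, 2))    # [(0, 1), (1, 2)]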
@@ -542,8 +542,8 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = None
 
         # Store activation checkpointing flag.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
 
         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -609,12 +609,31 @@
         # Make sure memory is freed.
         mpu.reset_checkpointed_activations_memory_buffer()
 
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # A method to further reduce memory usage by reducing the number of checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method to fully use the device memory by removing redundant re-computation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")
 
         return hidden_states
@@ -637,7 +656,7 @@ class ParallelTransformer(MegatronModule):
                 'for not None values in layer_past, ' \
                 'expected get_key_value to be set'
         if get_key_value:
-            assert not self.checkpoint_activations, \
+            assert self.activations_checkpoint_method is None, \
                 'get_key_value does not work with ' \
                 'activation checkpointing'
 
@@ -656,7 +675,7 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
 
-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
...
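mpu.checkpoint is Megatron's own activation-checkpoint wrapper (it additionally tracks RNG state for dropout). As a rough illustration only, the sketch below uses the stock torch.utils.checkpoint as an assumed stand-in, not a drop-in replacement, to show the trade the two forward paths make: checkpointed spans keep only their inputs and re-run the forward during backward, while un-checkpointed layers keep their activations.

    import torch
    from torch.utils.checkpoint import checkpoint

    # Toy stand-ins for a chunk of Transformer layers returned by custom(start, end).
    layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])

    def custom(start, end):
        # Returns a callable that runs layers[start:end], like custom() in
        # ParallelTransformer._checkpointed_forward.
        def custom_forward(x):
            for layer in layers[start:end]:
                x = torch.relu(layer(x))
            return x
        return custom_forward

    x = torch.randn(2, 16, requires_grad=True)

    # 'uniform' flavour: every chunk of 2 layers is checkpointed (recomputed in backward).
    h = x
    for l in range(0, 4, 2):
        h = checkpoint(custom(l, l + 2), h, use_reentrant=False)

    # 'block' flavour: only the first layer is checkpointed, the rest run normally.
    h2 = checkpoint(custom(0, 1), x, use_reentrant=False)
    for l in range(1, 4):
        h2 = custom(l, l + 1)(h2)

    (h.sum() + h2.sum()).backward()  # gradients flow through both paths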
@@ -47,9 +47,18 @@ def init_checkpointed_activations_memory_buffer():
     per_layer = args.micro_batch_size * args.max_position_embeddings * \
                 args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
+
+    if args.activations_checkpoint_method == 'uniform':
+        assert num_layers % args.activations_checkpoint_num_layers == 0, \
+            'total number of layers is not divisible by activations-checkpoint-num-layers'
+        num_checkpointer_layers = num_layers // args.activations_checkpoint_num_layers
+    elif args.activations_checkpoint_method == 'block':
+        assert args.activations_checkpoint_num_layers <= num_layers, \
+            'total number of layers is fewer than the number of layers to checkpoint'
+        num_checkpointer_layers = args.activations_checkpoint_num_layers
+
     numel = per_layer * num_checkpointer_layers
     dtype = torch.half
     if not args.fp16:
...
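A worked example of the buffer sizing above, with made-up configuration values chosen only to make the arithmetic concrete (in Megatron these come from the parsed args):

    # Hypothetical configuration, for arithmetic only.
    micro_batch_size = 4
    max_position_embeddings = 2048
    hidden_size = 4096
    tensor_model_parallel_size = 8
    pipeline_model_parallel_size = 4
    num_layers_total = 32
    activations_checkpoint_num_layers = 2

    # Elements contributed by one checkpointed layer input on this rank.
    per_layer = micro_batch_size * max_position_embeddings * hidden_size \
                // tensor_model_parallel_size                   # 4 * 2048 * 4096 / 8 = 4,194,304

    # Only the layers hosted by this pipeline stage matter (the point of this commit).
    num_layers = num_layers_total // pipeline_model_parallel_size  # 32 / 4 = 8

    # uniform: one checkpoint per chunk of 2 layers -> 8 / 2 = 4 checkpointed inputs.
    numel_uniform = per_layer * (num_layers // activations_checkpoint_num_layers)
    # block: exactly 2 individual layers are checkpointed on this stage.
    numel_block = per_layer * activations_checkpoint_num_layers

    print(numel_uniform)  # 16,777,216 elements
    print(numel_block)    # 8,388,608 elements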