"vscode:/vscode.git/clone" did not exist on "169d5169fe4f805f39eef4a5b0aa2fe480190afe"
Commit c1e0689d authored by slym's avatar slym
Browse files

Checkpoint a set number of individual Transformer layers

consider the case of pipeline-model parallelism

clean up arguments

argument naming cleanup

update readme and examples
parent 68797d90
......@@ -156,7 +156,7 @@ OUTPUT_ARGS="--log-interval 10 \
--save-interval 500 \
--eval-interval 100 \
--eval-iters 10 \
--checkpoint-activations"
--activations-checkpoint-method uniform"
python pretrain_bert.py \
$BERT_ARGS \
......@@ -345,7 +345,7 @@ python pretrain_ict.py \
--max-position-embeddings 256 \
--ict-head-size 128 \
--train-iters 100000 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--bert-load /path/to/pretrained_bert \
--load checkpoints \
--save checkpoints \
......@@ -375,7 +375,7 @@ python tools/create_doc_index.py \
--ict-head-size 128 \
--num-attention-heads 12 \
--batch-size 128 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--seq-length 256 \
--max-position-embeddings 256 \
--ict-load /path/to/pretrained_ict \
......@@ -482,7 +482,7 @@ python tasks/main.py \
--merge-file $MERGE_FILE \
--load $CHECKPOINT_PATH \
--micro-batch-size 8 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--log-interval 10 \
--no-load-optim \
--no-load-rng
......@@ -512,7 +512,7 @@ python tasks/main.py \
--merge-file $MERGE_FILE \
--load $CHECKPOINT_PATH \
--micro-batch-size 8 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--log-interval 10 \
--no-load-optim \
--no-load-rng
......@@ -542,7 +542,7 @@ COMMON_TASK_ARGS="--num-layers 24 \
COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
--valid-data $VALID_DATA \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--save-interval 10000 \
--save $CHECKPOINT_PATH \
--log-interval 100 \
......
......@@ -20,7 +20,7 @@ python tasks/main.py \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--seq-length 512 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
......
......@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 8 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--seq-length 1024 \
--max-position-embeddings 1024 \
--log-interval 10 \
......
......@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 8 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--lr 5.0e-5 \
--lr-decay-style linear \
--lr-warmup-fraction 0.065 \
......
......@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 4 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--lr 1.0e-5 \
--lr-decay-style linear \
--lr-warmup-fraction 0.06 \
......
......@@ -33,7 +33,7 @@ python pretrain_gpt.py \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
......
......@@ -49,7 +49,7 @@ options=" \
--init-method-std 0.006 \
--tensorboard-dir <TENSORBOARD DIRECTORY> \
--fp16 \
--checkpoint-activations "
--activations-checkpoint-method uniform "
run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"
......
......@@ -40,7 +40,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
......
......@@ -42,7 +42,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--activations-checkpoint-method uniform \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
......
......@@ -91,6 +91,12 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
print('--checkpoint-activations is no longer valid, '
'use --activations-checkpoint-method instead. '
'Defaulting to activations-checkpoint-method=uniform.')
args.activations_checkpoint_method = 'uniform'
del args.checkpoint_activations
# Set input defaults.
for key in defaults:
......@@ -234,9 +240,9 @@ def parse_args(extra_args_provider=None, defaults={},
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
'need to set a valid --activations-checkpoint-method (\'uniform\' or \'block\')'
_print_args(args)
return args
......@@ -402,8 +408,19 @@ def _add_training_args(parser):
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) block: checkpoint the input activation of only a set '
'number of individual Transformer layers and skip the rest, '
'default) do not apply activation checkpointing to any layer')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......
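The help text above distinguishes the two methods by which layer inputs get checkpointed. As a rough illustration (not Megatron code; the layer counts below are made up), 'uniform' checkpoints the input of every chunk of activations-checkpoint-num-layers layers, while 'block' checkpoints only the first activations-checkpoint-num-layers layers of a pipeline stage:

def checkpointed_layer_starts(method, num_layers, num_ckpt_layers):
    # Return the indices of the Transformer layers whose inputs are checkpointed.
    if method == 'uniform':
        # One checkpoint at the start of each chunk of num_ckpt_layers layers.
        return list(range(0, num_layers, num_ckpt_layers))
    if method == 'block':
        # Checkpoint only the first num_ckpt_layers layers of this stage.
        return list(range(min(num_ckpt_layers, num_layers)))
    return []  # default: no activation checkpointing

print(checkpointed_layer_starts('uniform', 12, 2))  # [0, 2, 4, 6, 8, 10]
print(checkpointed_layer_starts('block', 12, 2))    # [0, 1]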
......@@ -542,8 +542,8 @@ class ParallelTransformer(MegatronModule):
self.input_tensor = None
# Store activation checkpointing settings.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
# Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
......@@ -609,12 +609,31 @@ class ParallelTransformer(MegatronModule):
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by storing fewer checkpoints (one per chunk).
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method that makes full use of device memory and avoids redundant re-computation for the remaining layers.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
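Outside of Megatron's mpu.checkpoint wrapper, the same control flow can be sketched with torch.utils.checkpoint. The snippet below is a toy stand-in (hypothetical linear layers, not the ParallelTransformer API) that mirrors the uniform and block branches above:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stack of "Transformer layers"; sizes are arbitrary.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])

def custom(start, end):
    # Run layers[start:end] as one checkpointable unit.
    def custom_forward(x):
        for layer in layers[start:end]:
            x = layer(x)
        return x
    return custom_forward

def forward(x, method, num_ckpt_layers):
    num_layers = len(layers)
    if method == 'uniform':
        # Checkpoint the input of each chunk of num_ckpt_layers layers.
        l = 0
        while l < num_layers:
            x = checkpoint(custom(l, l + num_ckpt_layers), x)
            l += num_ckpt_layers
    elif method == 'block':
        # Checkpoint only the first num_ckpt_layers layers; run the rest normally.
        for l in range(num_layers):
            if l < num_ckpt_layers:
                x = checkpoint(custom(l, l + 1), x)
            else:
                x = custom(l, l + 1)(x)
    else:
        x = custom(0, num_layers)(x)
    return x

x = torch.randn(4, 16, requires_grad=True)
forward(x, 'block', 2).sum().backward()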
......@@ -637,7 +656,7 @@ class ParallelTransformer(MegatronModule):
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
assert self.activations_checkpoint_method is None, \
'get_key_value does not work with ' \
'activation checkpointing'
......@@ -656,7 +675,7 @@ class ParallelTransformer(MegatronModule):
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations:
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
......
......@@ -47,9 +47,18 @@ def init_checkpointed_activations_memory_buffer():
per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.tensor_model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
if args.virtual_pipeline_model_parallel_size is not None:
num_layers = num_layers // args.virtual_pipeline_model_parallel_size
if args.activations_checkpoint_method == 'uniform':
assert num_layers % args.activations_checkpoint_num_layers == 0, \
'number of layers per pipeline stage is not divisible by activations-checkpoint-num-layers'
num_checkpointer_layers = num_layers // args.activations_checkpoint_num_layers
elif args.activations_checkpoint_method == 'block':
assert args.activations_checkpoint_num_layers <= num_layers, \
'number of layers to checkpoint exceeds the number of layers per pipeline stage'
num_checkpointer_layers = args.activations_checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not args.fp16:
......
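To make the buffer sizing concrete, here is a small worked example with hypothetical dimensions (none of these numbers appear in the diff), following the per_layer and num_checkpointer_layers computation above:

# All values are hypothetical, for illustration only.
micro_batch_size = 4
max_position_embeddings = 1024
hidden_size = 2048
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4
total_num_layers = 24
activations_checkpoint_num_layers = 2

# Elements of one layer's checkpointed input on this tensor-parallel rank.
per_layer = micro_batch_size * max_position_embeddings * \
    hidden_size // tensor_model_parallel_size                    # 4_194_304

# Layers handled by one pipeline stage (no virtual pipeline here).
num_layers = total_num_layers // pipeline_model_parallel_size    # 6

# 'uniform': one checkpoint per chunk of activations_checkpoint_num_layers layers.
uniform_numel = per_layer * (num_layers // activations_checkpoint_num_layers)  # 12_582_912
# 'block': one checkpoint per individually checkpointed layer.
block_numel = per_layer * activations_checkpoint_num_layers                    # 8_388_608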