Commit c61dc22f authored by mshoeybi

some cleanup

parent b8940b96
@@ -240,10 +240,12 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.tensor_model_parallel_size > 1
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method'
 
     _print_args(args)
     return args
...
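For context, here is a minimal sketch of the validation this first hunk adds, written as a standalone function so it can be run outside Megatron. The `Args` dataclass is a hypothetical stand-in for the parsed argparse namespace; only the two assertions mirror the checks in the diff above.

```python
# Minimal sketch of the argument validation added above (illustrative only).
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    # Hypothetical stand-in for Megatron's argparse namespace.
    distribute_checkpointed_activations: bool = False
    tensor_model_parallel_size: int = 1
    activations_checkpoint_method: Optional[str] = None  # 'uniform' or 'block'


def validate_checkpointing_args(args: Args) -> None:
    if args.distribute_checkpointed_activations:
        # Distributing checkpointed activations splits them across the
        # tensor-model-parallel group, so that group must have more than one rank.
        assert args.tensor_model_parallel_size > 1, \
            'can distribute checkpointed activations only across ' \
            'tensor model parallel groups'
        # Distribution only happens inside activation checkpointing, so a
        # checkpoint method must be selected as well.
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you need ' \
            'to use an activation-checkpoint method (uniform or block)'


# Passes: tensor parallelism > 1 and a checkpoint method are both set.
validate_checkpointing_args(Args(distribute_checkpointed_activations=True,
                                 tensor_model_parallel_size=2,
                                 activations_checkpoint_method='uniform'))
```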
@@ -608,6 +608,23 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+            parallel ranks if `distribute-checkpointed-activations` is on
+            and either of the following conditions is met:
+              - it is not the first layer in the pipeline stage.
+                The first layer is used in the pipeline parallelism
+                and changing its shape throws an error in the backward pass.
+              - we are at the first pipeline stage, so the input tensor is
+                not used in pipeline parallelism. Note that no pipeline
+                parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
+
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
@@ -616,7 +633,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                    distribute_checkpointed_activations_helper(l),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -627,7 +644,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                        distribute_checkpointed_activations_helper(l),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
...
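The condition computed by `distribute_checkpointed_activations_helper` can be exercised in isolation. The sketch below restates it as a free function with the pipeline rank passed in explicitly rather than read from `mpu`; the name and signature are illustrative, not Megatron's API.

```python
# Standalone restatement of the helper's condition, with the pipeline rank
# passed explicitly instead of queried from mpu (illustrative only).
def should_distribute(distribute_checkpointed_activations: bool,
                      layer_number: int,
                      pipeline_rank: int) -> bool:
    not_first_layer_in_pipeline_stage = layer_number > 0
    is_first_pipeline_stage = pipeline_rank == 0
    return distribute_checkpointed_activations and (
        not_first_layer_in_pipeline_stage or is_first_pipeline_stage)


# The only case that skips distribution is layer 0 on a non-first pipeline
# stage: that layer's input is the tensor received from the previous stage
# and must keep its shape for the backward pass.
assert should_distribute(True, layer_number=0, pipeline_rank=0)
assert not should_distribute(True, layer_number=0, pipeline_rank=1)
assert should_distribute(True, layer_number=3, pipeline_rank=1)
assert not should_distribute(False, layer_number=3, pipeline_rank=0)
```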