Commit b8940b96 authored by mshoeybi

added for pp

parent 7f2cc3a4
megatron/arguments.py

@@ -240,10 +240,10 @@ def parse_args(extra_args_provider=None, defaults={},
             'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
             'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
-        assert args.pipeline_model_parallel_size == 1
     _print_args(args)
     return args
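Functionally, the updated checks tie --distribute-checkpointed-activations to tensor model parallelism instead of forbidding pipeline parallelism. Below is a minimal standalone sketch of the validation after this commit; the helper name, the Namespace example, and the surrounding wiring are illustrative, and only the two asserts come from the diff.

from argparse import Namespace

def check_distribute_checkpointed_activations(args):
    # Sketch only: distributing checkpointed activations shards them across
    # the tensor model parallel group, so it needs TP > 1 and a selected
    # activation-checkpointing method; the old requirement that
    # pipeline_model_parallel_size == 1 is dropped.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you ' \
            'need to use a valid checkpoint-activation method ' \
            "('uniform' or 'block')"
    return args

# Example: a configuration that is valid under the new checks.
check_distribute_checkpointed_activations(Namespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,       # now allowed
    activations_checkpoint_method='uniform'))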
megatron/model/transformer.py

@@ -616,7 +616,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -627,7 +627,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
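The effect of the new second argument to mpu.checkpoint is that activation distribution is skipped only for the first checkpointed layer (l == 0) on pipeline stages after the first, presumably because that layer's input activation arrives from the previous pipeline stage; everywhere else the user's setting is applied unchanged. A small self-contained sketch of that predicate follows; the helper name and the demo loop are illustrative and not part of the commit.

def distribute_checkpointed_activation(enabled, layer_idx, pipeline_rank):
    # Mirrors the condition added in this commit: distribute checkpointed
    # activations except for the first checkpointed layer (layer_idx == 0)
    # on pipeline stages other than the first (pipeline_rank > 0).
    return enabled and ((layer_idx > 0) or (pipeline_rank == 0))

if __name__ == '__main__':
    # With the flag enabled: rank 0 distributes every layer, later ranks
    # skip only their first checkpointed layer.
    for rank in (0, 1):
        for layer in (0, 1, 2):
            print('pipeline rank', rank, 'layer', layer,
                  '->', distribute_checkpointed_activation(True, layer, rank))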