"vscode:/vscode.git/clone" did not exist on "2e5e058c17d68ed3202e8beea4b03befa0ca8248"
Commit b8940b96 authored by mshoeybi

added support for pipeline parallelism (pp) with distributed checkpointed activations

parent 7f2cc3a4
@@ -240,10 +240,10 @@ def parse_args(extra_args_provider=None, defaults={},
             'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
             'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
-        assert args.pipeline_model_parallel_size == 1

     _print_args(args)
     return args
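For readers skimming the hunk above: distributing checkpointed activations shards the saved inputs across the tensor-model-parallel group, which is why tensor_model_parallel_size > 1 is required, while the pipeline_model_parallel_size == 1 restriction is lifted by this commit. A minimal sketch of the check (a hypothetical standalone helper, not the committed code; the real checks are the asserts inside parse_args above):

    # Sketch only: `validate_activation_checkpointing` is a hypothetical
    # name for the asserts shown in the hunk above.
    def validate_activation_checkpointing(args):
        if args.distribute_checkpointed_activations:
            # Saved activations are sharded across the tensor-model-parallel
            # group, so more than one tensor-parallel rank is needed.
            assert args.tensor_model_parallel_size > 1
            # Only checkpointed activations can be distributed.
            assert args.activations_checkpoint_method in ('uniform', 'block')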
@@ -616,7 +616,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -627,7 +627,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
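The condition added in both hunks above computes the distribute flag per checkpointed span: it stays on for every span except the one starting at layer 0 on a non-first pipeline stage. A plausible reason (an assumption; the commit message does not say): on pipeline ranks greater than 0, the input to layer 0 is the activation received from the previous stage, and it is kept whole rather than scattered across the tensor-model-parallel group. A minimal sketch of the predicate, with hypothetical parameter names standing in for self.distribute_checkpointed_activations, the loop index l, and mpu.get_pipeline_model_parallel_rank():

    # Sketch only; the parameter names are stand-ins, not Megatron APIs.
    def should_distribute(distribute_checkpointed, l, pipeline_rank):
        # Scatter the saved input for every checkpointed span except the
        # first one (l == 0) on non-first pipeline stages, whose input is
        # the tensor received from the previous stage.
        return distribute_checkpointed and ((l > 0) or (pipeline_rank == 0))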