Commit d19e8ad7 authored by Vijay Korthikanti

address review comments

parent bbab79f8
@@ -286,7 +286,7 @@ def parse_args(extra_args_provider=None, defaults={},
                       'Defaulting to no_persist_layer_norm=True')
     # Activation recomputing.
-    if args.distribute_recomputed_activations:
+    if args.distribute_saved_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'recomputed activations only across tensor model ' \
             'parallel groups'
@@ -502,7 +502,7 @@ def _add_training_args(parser):
                        'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
                        'layer is recomputed.')
-    group.add_argument('--distribute-recomputed-activations',
+    group.add_argument('--distribute-saved-activations',
                        action='store_true',
                        help='If set, distribute recomputed activations '
                        'across model parallel group.')
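For illustration, a minimal self-contained sketch (not the actual Megatron-LM arguments module) of how the renamed flag is parsed and validated, mirroring the argument definition and the tensor-parallel assert in the hunks above; the hard-coded argument list is only for demonstration:

# Hypothetical stand-alone sketch; in Megatron-LM these checks live in
# parse_args() and _add_training_args() as shown in the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tensor-model-parallel-size', type=int, default=1)
parser.add_argument('--distribute-saved-activations', action='store_true',
                    help='If set, distribute recomputed activations '
                         'across model parallel group.')
args = parser.parse_args(['--tensor-model-parallel-size', '2',
                          '--distribute-saved-activations'])

# argparse maps --distribute-saved-activations to args.distribute_saved_activations.
if args.distribute_saved_activations:
    assert args.tensor_model_parallel_size > 1, \
        'can distribute recomputed activations only across tensor model ' \
        'parallel groups'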
...
@@ -750,8 +750,8 @@ class ParallelTransformer(MegatronModule):
         self.recompute_granularity = args.recompute_granularity
         self.recompute_method = args.recompute_method
         self.recompute_num_layers = args.recompute_num_layers
-        self.distribute_recomputed_activations = \
-            args.distribute_recomputed_activations and not args.sequence_parallel
+        self.distribute_saved_activations = \
+            args.distribute_saved_activations and not args.sequence_parallel
         self.sequence_parallel = args.sequence_parallel
@@ -851,7 +851,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.recompute_num_layers),
-                    self.distribute_recomputed_activations,
+                    self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.recompute_num_layers
@@ -863,7 +863,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_recomputed_activations,
+                        self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
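A minimal sketch of the uniform-recompute loop in the hunk above, with torch.utils.checkpoint.checkpoint standing in for mpu.checkpoint (the real call additionally takes self.distribute_saved_activations as its second argument); the layer sizes and module names here are illustrative only:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

num_layers, recompute_num_layers = 8, 2
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(num_layers)])

def custom(start, end):
    # Run layers [start, end) in sequence, like ParallelTransformer's custom().
    def custom_forward(hidden_states):
        for layer in layers[start:end]:
            hidden_states = layer(hidden_states)
        return hidden_states
    return custom_forward

hidden_states = torch.randn(4, 16, requires_grad=True)
l = 0
while l < num_layers:
    # In Megatron-LM this is mpu.checkpoint(custom(l, l + recompute_num_layers),
    #                                       self.distribute_saved_activations, ...)
    hidden_states = checkpoint(custom(l, l + recompute_num_layers),
                               hidden_states)
    l += recompute_num_layers
hidden_states.sum().backward()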
...
@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
       tracked/set/reset.
    """
    @staticmethod
-    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
-        ctx.distribute_checkpointed_activations \
-            = distribute_checkpointed_activations
+        ctx.distribute_saved_activations \
+            = distribute_saved_activations
        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
-        if distribute_checkpointed_activations:
+        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
-        if ctx.distribute_checkpointed_activations:
+        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
        return (None, None) + grads

-def checkpoint(function, distribute_checkpointed_activations, *args):
+def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function,
-                                    distribute_checkpointed_activations, *args)
+                                    distribute_saved_activations, *args)
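To make the distribute_saved_activations path above concrete, here is a single-process, assumption-laden sketch of the underlying idea; it does not reproduce the mpu helpers (e.g. gather_split_1d_tensor) or any collective communication, but shows the same bookkeeping: in forward each tensor-parallel rank keeps only its slice of the flattened checkpointed input, and the full tensor is rebuilt to ctx.input_0_shape before recomputation in backward.

# Single-process illustration only; names here are hypothetical and the real
# code performs the split/gather across the tensor model parallel group.
import torch

world_size = 4                       # tensor model parallel size (assumed)
activation = torch.randn(2, 3, 8)    # checkpointed input, e.g. hidden_states
input_0_shape = activation.shape     # what ctx.input_0_shape records

# Forward: flatten and split into equal 1-D chunks, one per rank.
flat = activation.reshape(-1)
per_rank_chunks = list(flat.chunk(world_size))   # rank r would store only chunk r

# Backward: gather every rank's chunk and restore the original shape
# before re-running the forward function for recomputation.
restored = torch.cat(per_rank_chunks).view(input_0_shape)
assert torch.equal(restored, activation)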