Commit d19e8ad7 authored by Vijay Korthikanti

address review comments

parent bbab79f8
@@ -286,7 +286,7 @@ def parse_args(extra_args_provider=None, defaults={},
                   'Defaulting to no_persist_layer_norm=True')
 
     # Activation recomputing.
-    if args.distribute_recomputed_activations:
+    if args.distribute_saved_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'recomputed activations only across tensor model ' \
             'parallel groups'
@@ -502,7 +502,7 @@ def _add_training_args(parser):
                        'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
                        'layer is recomputed.')
-    group.add_argument('--distribute-recomputed-activations',
+    group.add_argument('--distribute-saved-activations',
                        action='store_true',
                        help='If set, distribute recomputed activations '
                        'across model parallel group.')
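Note on the flag rename above: argparse turns the dashes in --distribute-saved-activations into underscores, which is why the rest of the diff reads args.distribute_saved_activations, and action='store_true' makes the attribute default to False. A minimal standalone sketch (not Megatron code; only the flag name and help text are taken from the diff):

import argparse

# Minimal sketch of the flag wiring; everything except the flag name and
# help text is illustrative.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--distribute-saved-activations',
                   action='store_true',
                   help='If set, distribute recomputed activations '
                        'across model parallel group.')

args = parser.parse_args([])
assert args.distribute_saved_activations is False
args = parser.parse_args(['--distribute-saved-activations'])
assert args.distribute_saved_activations is True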
@@ -750,8 +750,8 @@ class ParallelTransformer(MegatronModule):
         self.recompute_granularity = args.recompute_granularity
         self.recompute_method = args.recompute_method
         self.recompute_num_layers = args.recompute_num_layers
-        self.distribute_recomputed_activations = \
-            args.distribute_recomputed_activations and not args.sequence_parallel
+        self.distribute_saved_activations = \
+            args.distribute_saved_activations and not args.sequence_parallel
         self.sequence_parallel = args.sequence_parallel
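The "and not args.sequence_parallel" guard above silently disables the option when sequence parallelism is on; a plausible reading (an assumption, not stated in this diff) is that sequence parallelism already partitions activations along the sequence dimension across the tensor-model-parallel ranks, so there is nothing left to distribute. A tiny sketch of the resulting behaviour, with a hypothetical helper name:

# Sketch of the flag resolution shown above; the helper name is hypothetical.
def resolve_distribute_saved_activations(requested, sequence_parallel):
    return requested and not sequence_parallel

assert resolve_distribute_saved_activations(True, False) is True
assert resolve_distribute_saved_activations(True, True) is False    # forced off
assert resolve_distribute_saved_activations(False, False) is False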
@@ -851,7 +851,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.recompute_num_layers),
-                    self.distribute_recomputed_activations,
+                    self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
 
                 l += self.recompute_num_layers
@@ -863,7 +863,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_recomputed_activations,
+                        self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
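In both recompute loops above, custom(l, l + n) builds a callable over a contiguous slice of layers, and mpu.checkpoint now receives self.distribute_saved_activations as its second argument. A self-contained sketch of that closure pattern, with dummy callables standing in for the real transformer layers (an illustration, not the Megatron implementation):

# Dummy per-layer callables standing in for ParallelTransformer layers.
layers = [lambda hidden, mask, enc_out, enc_dec_mask: hidden + 1
          for _ in range(4)]

def custom(start, end):
    """Return a forward function over layers[start:end], suitable for
    handing to an activation-checkpoint wrapper such as mpu.checkpoint."""
    def custom_forward(hidden_states, attention_mask,
                       encoder_output, enc_dec_attn_mask):
        for layer in layers[start:end]:
            hidden_states = layer(hidden_states, attention_mask,
                                  encoder_output, enc_dec_attn_mask)
        return hidden_states
    return custom_forward

# Recompute two layers at a time, as in the 'uniform' loop above.
out = custom(0, 2)(0, None, None, None)
assert out == 2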
@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
              tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
         ctx.run_function = run_function
-        ctx.distribute_checkpointed_activations \
-            = distribute_checkpointed_activations
+        ctx.distribute_saved_activations \
+            = distribute_saved_activations
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
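The forward pass above snapshots the CPU RNG state (and, per the class docstring, the CUDA and model-parallel tracker states) so that recomputation in backward reproduces the same dropout masks. A minimal single-generator illustration of why the snapshot/restore pair matters:

import torch

# Save the RNG state before the "forward" pass, as ctx.fwd_cpu_rng_state does.
state = torch.get_rng_state()
original = torch.rand(3)          # e.g. a dropout mask drawn in forward

# Restore it before recomputation so the same random numbers are drawn again.
torch.set_rng_state(state)
recomputed = torch.rand(3)
assert torch.equal(original, recomputed)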
@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if distribute_checkpointed_activations:
+        if distribute_saved_activations:
             ctx.input_0_shape = args[0].data.shape
             safely_set_viewless_tensor_data(
                 args[0],
@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if ctx.distribute_checkpointed_activations:
+        if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
                 inputs[0],
                 gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads
 
 
-def checkpoint(function, distribute_checkpointed_activations, *args):
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
     return CheckpointFunction.apply(function,
-                                    distribute_checkpointed_activations, *args)
+                                    distribute_saved_activations, *args)
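For context on what distribute_saved_activations does with the first saved input: forward stores a split of args[0] (the call is truncated above) and backward reassembles it with gather_split_1d_tensor(...).view(ctx.input_0_shape). A single-process sketch of that shape bookkeeping, assuming the split/gather pair partitions the flattened tensor into equal 1-D chunks across the tensor-model-parallel ranks; in the real code each rank keeps only its own chunk and the gather is a cross-rank all-gather:

import torch

def split_into_1d_chunks(tensor, world_size):
    # Stand-in for split_tensor_into_1d_equal_chunks: flatten, then keep
    # numel() // world_size elements per rank (all chunks kept here).
    return torch.chunk(tensor.reshape(-1), world_size)

def gather_1d_chunks(chunks, original_shape):
    # Stand-in for gather_split_1d_tensor(...).view(original_shape).
    return torch.cat(chunks).view(original_shape)

x = torch.randn(4, 2, 8)                      # [seq, batch, hidden] stand-in
chunks = split_into_1d_chunks(x, world_size=4)
assert chunks[0].numel() == x.numel() // 4    # each rank stores 1/4 of x
assert torch.equal(gather_1d_chunks(chunks, x.shape), x)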