Commit bbab79f8 authored by Vijay Korthikanti

address review comments

parent 356eb36a
@@ -105,14 +105,19 @@ def parse_args(extra_args_provider=None, defaults={},
     del args.model_parallel_size

     if args.checkpoint_activations:
-        args.checkpoint_granularity = 'full'
-        args.checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --checkpoint-granularity and --checkpoint-method instead. '
-                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
     del args.checkpoint_activations

+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+        args.recompute_method = 'uniform'
+    del args.recompute_activations
+
     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
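The flag mapping in this hunk can be exercised on its own. The sketch below is a minimal, self-contained illustration (it uses a throwaway argparse parser, not Megatron's actual parse_args) of how the deprecated --checkpoint-activations flag and the new --recompute-activations shorthand translate into the recompute_granularity / recompute_method settings.

import argparse

# Illustrative parser only; the real flags are registered in _add_training_args.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated
parser.add_argument('--recompute-activations', action='store_true')   # new shorthand
parser.add_argument('--recompute-granularity', default=None,
                    choices=['full', 'selective'])
parser.add_argument('--recompute-method', default=None,
                    choices=['uniform', 'block'])
args = parser.parse_args(['--checkpoint-activations'])

# Deprecated flag falls back to full/uniform recompute, as in the hunk above.
if args.checkpoint_activations:
    args.recompute_granularity = 'full'
    args.recompute_method = 'uniform'
del args.checkpoint_activations

# The new shorthand selects selective granularity.
if args.recompute_activations:
    args.recompute_granularity = 'selective'
    args.recompute_method = 'uniform'
del args.recompute_activations

print(args.recompute_granularity, args.recompute_method)  # -> full uniform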
@@ -280,26 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
               'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
               'Defaulting to no_persist_layer_norm=True')

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    # Activation recomputing.
+    if args.distribute_recomputed_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
             'parallel groups'
-        assert args.checkpoint_granularity == 'full', \
-            'distributed checkpoint activations is only '\
-            'application to full checkpoint granularity'
-        assert args.checkpoint_method is not None, \
-            'for distributed checkpoint activations to work you '\
-            'need to use a checkpoint method '
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only '\
+            'application to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you '\
+            'need to use a recompute method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
-            'distributed checkpoint activations are supported for pytorch ' \
+            'distributed recompute activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

-    if args.checkpoint_granularity == 'selective':
-        assert args.checkpoint_method is None, \
-            'checkpoint method is not yet supported for ' \
-            'selective checkpointing granularity'
+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recomputing granularity'

     # disable async_tensor_model_parallel_allreduce when
     # model parallel memory optimization is enabled
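For reference, the constraints these assertions enforce can be restated as a standalone checker. The sketch below is illustrative only (the helper name and signature are made up here) and repeats the checks from the hunk above, minus the PyTorch version guard.

def check_recompute_args(granularity, method, distribute_recomputed, tp_size):
    # Mirrors the argument validation in the hunk above (illustrative helper).
    if distribute_recomputed:
        # Distributing recomputed activations needs tensor parallelism ...
        assert tp_size > 1
        # ... and only applies to full granularity with an explicit method.
        assert granularity == 'full'
        assert method is not None
    if granularity == 'selective':
        # Selective recompute does not yet take a method.
        assert method is None

# Valid: full granularity, uniform method, distributed across TP ranks.
check_recompute_args('full', 'uniform', True, tp_size=2)
# Valid: selective granularity with no method.
check_recompute_args('selective', None, False, tp_size=1)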
@@ -486,33 +491,35 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-
-    group.add_argument('--checkpoint-granularity', type=str, default=None,
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='recompute activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--recompute-granularity', type=str, default=None,
                        choices=['full', 'selective'],
                        help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
-                       'whole transformer layer is checkpointed, '
+                       'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
-                       'layer is checkpointed.')
-    group.add_argument('--distribute-checkpointed-activations',
+                       'layer is recomputed.')
+    group.add_argument('--distribute-recomputed-activations',
                        action='store_true',
-                       help='If set, distribute checkpointed activations '
+                       help='If set, distribute recomputed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-method', type=str, default=None,
+    group.add_argument('--recompute-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
-                       'Transformer layers and checkpoint the input activation of '
+                       'Transformer layers and recompute the input activation of '
                        'each divided chunk at specified granularity, '
-                       '2) checkpoint the input activations of only a set number of '
+                       '2) recompute the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing at specified granularity'
-                       'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
+                       'rest without any recomputing at specified granularity'
+                       'default) do not apply activations recompute to any layers')
+    group.add_argument('--recompute-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
-                       'uniformly divided checkpoint unit, '
+                       'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
-                       'to checkpoint within each pipeline stage.')
+                       'to recompute within each pipeline stage.')

     # deprecated
     group.add_argument('--checkpoint-activations', action='store_true',
...
@@ -242,6 +242,10 @@ class CoreAttention(MegatronModule):
                 output_size[3],
                 dtype=query_layer.dtype,
                 device=torch.cuda.current_device())
+        else:
+            assert CoreAttention.matmul_input_buffer.size() == \
+                (output_size[0]*output_size[1], output_size[2], output_size[3]), \
+                "buffer dimensions should remain the same during the training run"

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
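The new else branch guards reuse of the class-level matmul_input_buffer: the workspace is allocated once on first use, and every later call is required to see the same shape. A minimal CPU-only sketch of that allocate-once / assert-later pattern (class and method names here are illustrative, not the actual CoreAttention code):

import torch

class AttentionLike:
    # Shared across all instances, allocated lazily on first use.
    matmul_input_buffer = None

    def get_buffer(self, b_np, sq, sk, dtype):
        if AttentionLike.matmul_input_buffer is None:
            # First call: allocate the workspace once.
            AttentionLike.matmul_input_buffer = torch.empty(b_np, sq, sk, dtype=dtype)
        else:
            # Later calls: the shape must not change mid-run, otherwise the
            # reused workspace would silently have the wrong size.
            assert AttentionLike.matmul_input_buffer.size() == (b_np, sq, sk), \
                "buffer dimensions should remain the same during the training run"
        return AttentionLike.matmul_input_buffer

buf1 = AttentionLike().get_buffer(8, 128, 128, torch.float32)
buf2 = AttentionLike().get_buffer(8, 128, 128, torch.float32)
assert buf1 is buf2  # the same storage is reused across calls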
@@ -358,7 +362,7 @@ class ParallelAttention(MegatronModule):
         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
-        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(
@@ -743,11 +747,11 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
-        self.checkpoint_granularity = args.checkpoint_granularity
-        self.checkpoint_method = args.checkpoint_method
-        self.checkpoint_num_layers = args.checkpoint_num_layers
-        self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.sequence_parallel
+        self.recompute_granularity = args.recompute_granularity
+        self.recompute_method = args.recompute_method
+        self.recompute_num_layers = args.recompute_num_layers
+        self.distribute_recomputed_activations = \
+            args.distribute_recomputed_activations and not args.sequence_parallel

         self.sequence_parallel = args.sequence_parallel
@@ -839,33 +843,33 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.checkpoint_method == 'uniform':
+        if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    custom(l, l + self.recompute_num_layers),
+                    self.distribute_recomputed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.checkpoint_num_layers
-        elif self.checkpoint_method == 'block':
+                l += self.recompute_num_layers
+        elif self.recompute_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.checkpoint_num_layers:
+                if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_recomputed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
         else:
-            raise ValueError("Invalid activation checkpoint method.")
+            raise ValueError("Invalid activation recompute method.")

         return hidden_states
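The uniform and block branches above differ only in which layers get wrapped in mpu.checkpoint. A small standalone sketch of that index selection (pure Python, no Megatron dependencies; the layer counts are example values):

def recomputed_chunks_uniform(num_layers, recompute_num_layers):
    # Every layer is recomputed, grouped into chunks of recompute_num_layers
    # (clamped here so a partial last chunk stays in range).
    chunks, l = [], 0
    while l < num_layers:
        chunks.append(list(range(l, min(l + recompute_num_layers, num_layers))))
        l += recompute_num_layers
    return chunks

def recomputed_layers_block(num_layers, recompute_num_layers):
    # Only the first recompute_num_layers layers of the pipeline stage are
    # recomputed; the rest run without checkpointing.
    return [l for l in range(num_layers) if l < recompute_num_layers]

print(recomputed_chunks_uniform(8, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(recomputed_layers_block(8, 2))    # [0, 1]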
@@ -886,7 +890,7 @@ class ParallelTransformer(MegatronModule):

         # Checks.
         if inference_params:
-            assert self.checkpoint_granularity is None, \
+            assert self.recompute_granularity is None, \
                 'inference does not work with activation checkpointing'

         if not self.pre_process:
@@ -921,7 +925,7 @@ class ParallelTransformer(MegatronModule):

         with rng_context:
             # Forward pass.
-            if self.checkpoint_granularity == 'full':
+            if self.recompute_granularity == 'full':
                 hidden_states = self._checkpointed_forward(hidden_states,
                                                            attention_mask,
                                                            encoder_output,
...
@@ -226,6 +226,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                     torch.empty(dim_size, dtype=input.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False)
+            else:
+                assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
+                    "buffer dimensions should remain same during the training run"
+
             torch.distributed._all_gather_base(
                 LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
                 input,
...
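The same allocate-once / assert-later guard is applied to the all-gather workspace; here dim_size is a Python list, hence the list(...) conversion before comparing (a torch.Size compares equal to a tuple but not to a list). A small sketch of just the shape check, with no distributed setup (the buffer and function names below are illustrative):

import torch

all_gather_buffer = None  # stands in for the class-level buffer above

def get_all_gather_buffer(dim_size, dtype):
    global all_gather_buffer
    if all_gather_buffer is None:
        # First call: allocate once; later calls reuse the same storage.
        all_gather_buffer = torch.empty(dim_size, dtype=dtype, requires_grad=False)
    else:
        # dim_size is a list, so convert torch.Size to a list before comparing.
        assert list(all_gather_buffer.size()) == dim_size, \
            "buffer dimensions should remain same during the training run"
    return all_gather_buffer

buf = get_all_gather_buffer([16, 4, 1024], torch.float32)
buf = get_all_gather_buffer([16, 4, 1024], torch.float32)  # ok: same shape
# get_all_gather_buffer([32, 4, 1024], torch.float32)      # would trip the assert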