Commit bbab79f8 authored by Vijay Korthikanti

address review comments

parent 356eb36a
@@ -105,14 +105,19 @@ def parse_args(extra_args_provider=None, defaults={},
         del args.model_parallel_size

     if args.checkpoint_activations:
-        args.checkpoint_granularity = 'full'
-        args.checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --checkpoint-granularity and --checkpoint-method instead. '
-                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
     del args.checkpoint_activations

+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+    del args.recompute_activations
+
     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
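This hunk is a deprecation shim: the old `--checkpoint-activations` boolean is still parsed, but it is translated into the new recompute options and then removed from the namespace. A minimal runnable sketch of the same pattern, using a bare `argparse` parser rather than Megatron's full `parse_args`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-activations', action='store_true')
parser.add_argument('--recompute-activations', action='store_true')
parser.add_argument('--recompute-granularity', default=None,
                    choices=['full', 'selective'])
parser.add_argument('--recompute-method', default=None,
                    choices=['uniform', 'block'])

args = parser.parse_args(['--checkpoint-activations'])
if args.checkpoint_activations:
    # The deprecated flag maps to full-granularity, uniform recompute.
    args.recompute_granularity = 'full'
    args.recompute_method = 'uniform'
del args.checkpoint_activations

if args.recompute_activations:
    # The new shorthand flag maps to selective granularity.
    args.recompute_granularity = 'selective'
del args.recompute_activations

print(args.recompute_granularity, args.recompute_method)  # full uniform
```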
@@ -280,26 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
                   'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                   'Defaulting to no_persist_layer_norm=True')

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    # Activation recomputing.
+    if args.distribute_recomputed_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
             'parallel groups'
-        assert args.checkpoint_granularity == 'full', \
-            'distributed checkpoint activations is only '\
-            'application to full checkpoint granularity'
-        assert args.checkpoint_method is not None, \
-            'for distributed checkpoint activations to work you '\
-            'need to use a checkpoint method '
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only '\
+            'applicable to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you '\
+            'need to use a recompute method'
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
-            'distributed checkpoint activations are supported for pytorch ' \
+            'distributed recompute activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

-    if args.checkpoint_granularity == 'selective':
-        assert args.checkpoint_method is None, \
-            'checkpoint method is not yet supported for ' \
-            'selective checkpointing granularity'
+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recompute granularity'

     # disable async_tensor_model_parallel_allreduce when
     # model parallel memory optimization is enabled
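Taken together, these asserts admit only a few flag combinations. A standalone sketch of the constraints they enforce (the helper name is hypothetical; the argument names mirror the Megatron args):

```python
def validate_recompute_args(tensor_model_parallel_size,
                            recompute_granularity,
                            recompute_method,
                            distribute_recomputed_activations):
    # Distributing recomputed activations requires tensor parallelism,
    # full granularity, and an explicit recompute method.
    if distribute_recomputed_activations:
        assert tensor_model_parallel_size > 1
        assert recompute_granularity == 'full'
        assert recompute_method is not None
    # Selective granularity does not yet support a recompute method.
    if recompute_granularity == 'selective':
        assert recompute_method is None

# Passes: full granularity, uniform method, distributed over 2 TP ranks.
validate_recompute_args(2, 'full', 'uniform', True)
# Passes: selective granularity with the method left unset.
validate_recompute_args(1, 'selective', None, False)
```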
@@ -486,33 +491,35 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-granularity', type=str, default=None,
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='Recompute activations to allow for training '
+                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--recompute-granularity', type=str, default=None,
                        choices=['full', 'selective'],
                        help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
-                       'whole transformer layer is checkpointed, '
+                       'whole transformer layer is recomputed, '
                        '2) selective: core attention part of the transformer '
-                       'layer is checkpointed.')
-    group.add_argument('--distribute-checkpointed-activations',
+                       'layer is recomputed.')
+    group.add_argument('--distribute-recomputed-activations',
                        action='store_true',
-                       help='If set, distribute checkpointed activations '
+                       help='If set, distribute recomputed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-method', type=str, default=None,
+    group.add_argument('--recompute-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
-                       'Transformer layers and checkpoint the input activation of '
+                       'Transformer layers and recompute the input activation of '
                        'each divided chunk at specified granularity, '
-                       '2) checkpoint the input activations of only a set number of '
+                       '2) block: recompute the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing at specified granularity'
-                       'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
+                       'rest without any recomputing at specified granularity, '
+                       'default) do not apply activation recompute to any layers.')
+    group.add_argument('--recompute-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
-                       'uniformly divided checkpoint unit, '
+                       'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
-                       'to checkpoint within each pipeline stage.')
+                       'to recompute within each pipeline stage.')

     # deprecated
     group.add_argument('--checkpoint-activations', action='store_true',
...
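To see how the renamed flags compose, an illustrative snippet that registers only the options from this hunk on a bare parser (real runs go through `parse_args`, which adds the validation shown earlier):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--recompute-activations', action='store_true')
group.add_argument('--recompute-granularity', default=None,
                   choices=['full', 'selective'])
group.add_argument('--distribute-recomputed-activations', action='store_true')
group.add_argument('--recompute-method', default=None,
                   choices=['uniform', 'block'])
group.add_argument('--recompute-num-layers', type=int, default=1)

# Full granularity, block method: recompute only the first 4 transformer
# layers of each pipeline stage; the rest keep their stored activations.
args = parser.parse_args('--recompute-granularity full '
                         '--recompute-method block '
                         '--recompute-num-layers 4'.split())
print(args.recompute_granularity, args.recompute_method,
      args.recompute_num_layers)  # full block 4
```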
@@ -242,6 +242,10 @@ class CoreAttention(MegatronModule):
                     output_size[3],
                     dtype=query_layer.dtype,
                     device=torch.cuda.current_device())
+            else:
+                assert CoreAttention.matmul_input_buffer.size() == \
+                    (output_size[0] * output_size[1], output_size[2], output_size[3]), \
+                    "buffer dimensions should remain the same during the training run"

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
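The new `else` branch guards a buffer that is allocated once and reused for every forward pass. A self-contained sketch of that caching pattern, with a simplified stand-in class (`CoreAttentionSketch` and `get_buffer` are hypothetical names) and CPU tensors so it runs anywhere:

```python
import torch

class CoreAttentionSketch:
    """Caches one class-level input buffer for baddbmm, as the hunk above
    does, and checks its shape on every reuse."""
    matmul_input_buffer = None

    @classmethod
    def get_buffer(cls, b_np, sq, sk, dtype, device):
        if cls.matmul_input_buffer is None:
            # Allocated once; reused by every subsequent forward pass.
            cls.matmul_input_buffer = torch.empty(
                b_np, sq, sk, dtype=dtype, device=device)
        else:
            # Shapes are assumed static for the whole training run.
            assert cls.matmul_input_buffer.size() == (b_np, sq, sk), \
                "buffer dimensions should remain the same during the training run"
        return cls.matmul_input_buffer

buf = CoreAttentionSketch.get_buffer(8, 128, 128, torch.float32, 'cpu')
# The [b*np, sq, sk] buffer feeds baddbmm as the input term; with beta=0.0
# its contents are ignored, only its storage is reused.
scores = torch.baddbmm(buf, torch.randn(8, 128, 64), torch.randn(8, 64, 128),
                       beta=0.0, alpha=1.0)
print(scores.shape)  # torch.Size([8, 128, 128])
```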
@@ -358,7 +362,7 @@ class ParallelAttention(MegatronModule):

         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
-        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(
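`checkpoint_core_attention` is what 'selective' granularity toggles: only the core attention block (the softmax(QK^T)V part, whose intermediates are large) runs under activation checkpointing, while the rest of the layer keeps its activations. A simplified sketch with `torch.utils.checkpoint` standing in for Megatron's `mpu.checkpoint` (`core_attention_fn` is a hypothetical stand-in for `CoreAttention.forward`):

```python
import torch
from torch.utils.checkpoint import checkpoint

def core_attention_fn(q, k, v):
    # The memory-hungry part: [b, np, sq, sk] scores and probabilities.
    scores = torch.matmul(q, k.transpose(-2, -1)) / q.size(-1) ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

q = torch.randn(2, 8, 128, 64, requires_grad=True)
k = torch.randn(2, 8, 128, 64, requires_grad=True)
v = torch.randn(2, 8, 128, 64, requires_grad=True)

checkpoint_core_attention = True  # i.e. recompute_granularity == 'selective'
if checkpoint_core_attention:
    # Intermediates are discarded in forward and recomputed in backward.
    out = checkpoint(core_attention_fn, q, k, v)
else:
    out = core_attention_fn(q, k, v)
out.sum().backward()
```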
@@ -743,11 +747,11 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

-        # Store activation checkpoiting flag.
-        self.checkpoint_granularity = args.checkpoint_granularity
-        self.checkpoint_method = args.checkpoint_method
-        self.checkpoint_num_layers = args.checkpoint_num_layers
-        self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.sequence_parallel
+        # Store activation recompute flags.
+        self.recompute_granularity = args.recompute_granularity
+        self.recompute_method = args.recompute_method
+        self.recompute_num_layers = args.recompute_num_layers
+        self.distribute_recomputed_activations = \
+            args.distribute_recomputed_activations and not args.sequence_parallel

         self.sequence_parallel = args.sequence_parallel
@@ -839,33 +843,33 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.checkpoint_method == 'uniform':
+        if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and
             # checkpoint the input activation of each divided chunk.
             # A method to further reduce memory usage by reducing the
             # number of stored checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    custom(l, l + self.recompute_num_layers),
+                    self.distribute_recomputed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.checkpoint_num_layers
+                l += self.recompute_num_layers

-        elif self.checkpoint_method == 'block':
+        elif self.recompute_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method that makes full use of device memory by removing
             # redundant recomputation.
             for l in range(self.num_layers):
-                if l < self.checkpoint_num_layers:
+                if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_recomputed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

         else:
-            raise ValueError("Invalid activation checkpoint method.")
+            raise ValueError("Invalid activation recompute method.")

         return hidden_states
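A worked example of the two schedules, for one pipeline stage with 8 layers and `recompute_num_layers = 2` (plain Python, no Megatron imports):

```python
num_layers, recompute_num_layers = 8, 2

# uniform: every layer is recomputed, checkpointed in chunks of 2 layers;
# chunks [0,2), [2,4), [4,6), [6,8) each store only their input activation.
uniform_chunks = [(l, l + recompute_num_layers)
                  for l in range(0, num_layers, recompute_num_layers)]
print(uniform_chunks)  # [(0, 2), (2, 4), (4, 6), (6, 8)]

# block: only the first 2 layers are checkpointed and recomputed;
# layers 2..7 run normally and keep all of their activations.
recomputed = [l for l in range(num_layers) if l < recompute_num_layers]
print(recomputed)      # [0, 1]
```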
@@ -886,7 +890,7 @@ class ParallelTransformer(MegatronModule):
         # Checks.
         if inference_params:
-            assert self.checkpoint_granularity is None, \
+            assert self.recompute_granularity is None, \
                 'inference does not work with activation checkpointing'

         if not self.pre_process:
@@ -921,7 +925,7 @@ class ParallelTransformer(MegatronModule):
         with rng_context:
             # Forward pass.
-            if self.checkpoint_granularity == 'full':
+            if self.recompute_granularity == 'full':
                 hidden_states = self._checkpointed_forward(hidden_states,
                                                            attention_mask,
                                                            encoder_output,
...
@@ -226,6 +226,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                 torch.empty(dim_size, dtype=input.dtype,
                             device=torch.cuda.current_device(),
                             requires_grad=False)
+        else:
+            assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
+                "buffer dimensions should remain the same during the training run"

         torch.distributed._all_gather_base(
             LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
             input,
...
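This is the same allocate-once pattern as the attention buffer, applied to the sequence-parallel all-gather. A runnable single-process sketch using the `gloo` backend (the function and buffer names are hypothetical; `torch.distributed._all_gather_base` is the same private API the hunk calls, later exposed as `all_gather_into_tensor`):

```python
import torch
import torch.distributed as dist

dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

_all_gather_buffer = None  # module-level stand-in for the class attribute

def gather_sequence_parallel(inp):
    global _all_gather_buffer
    world_size = dist.get_world_size()
    # Output is world_size times larger along the first dimension.
    dim_size = [inp.size(0) * world_size] + list(inp.size()[1:])
    if _all_gather_buffer is None:
        # Allocated once and reused: no fresh allocation per call.
        _all_gather_buffer = torch.empty(dim_size, dtype=inp.dtype,
                                         requires_grad=False)
    else:
        assert list(_all_gather_buffer.size()) == dim_size, \
            "buffer dimensions should remain the same during the training run"
    dist._all_gather_base(_all_gather_buffer, inp.contiguous())
    return _all_gather_buffer

out = gather_sequence_parallel(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 8]) with world_size == 1
dist.destroy_process_group()
```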