Commit 60750922 authored by Jared Casper

Merge branch 'lmcafee/dst-chk-act' into 'main'

Fixed distributed checkpoint activations to work on all layers

See merge request ADLR/megatron-lm!365
parents 2ad00f4e 1b28a51b
@@ -245,17 +245,6 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp32_residual_connection:
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
 
-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
-        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
-            'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
-            'for distribute-checkpointed-activations to work you ' \
-            'need to use an activation-checkpoint method '
-        assert args.num_layers_per_virtual_pipeline_stage is None, \
-            'currently distributed checkpoint activations are only supported ' \
-            'for non-interleaved pipeline parallelism'
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -267,6 +256,19 @@ def parse_args(extra_args_provider=None, defaults={},
             'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
             'Defaulting to no_persist_layer_norm=True')
 
+    # Activation checkpointing.
+    if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
+        assert args.activations_checkpoint_method is not None, \
+            'for distributed checkpoint activations to work you ' \
+            'need to use an activation-checkpoint method '
+        assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
+            'distributed checkpoint activations are supported for pytorch ' \
+            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
+            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
+
     _print_args(args)
     return args
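For context, the assertion block above did not merely move: it swapped the virtual-pipeline restriction for a torch-version check, which is presumably why it now sits below the lines that parse TORCH_MAJOR and TORCH_MINOR. Below is a minimal, hedged sketch (not Megatron code) of that kind of version gate; the helper name `supports_distributed_checkpointed_activations` is made up for illustration, and it compares the (major, minor) pair as a tuple, whereas the merged assertion checks the two numbers separately.

```python
# Illustrative sketch only -- not part of the merge request.
import torch

# Same version parsing as the hunk above: split torch.__version__ on '.'.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

def supports_distributed_checkpointed_activations(major=TORCH_MAJOR,
                                                  minor=TORCH_MINOR):
    """Hypothetical helper: True when the running torch is v1.10 or newer.

    Comparing (major, minor) as a tuple also accepts releases such as 2.0,
    where the minor number alone is smaller than 10.
    """
    return (major, minor) >= (1, 10)

if __name__ == '__main__':
    print(torch.__version__, supports_distributed_checkpointed_activations())
```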
@@ -626,23 +626,6 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
-        def distribute_checkpointed_activations_helper(layer_number):
-            """Distribute checkpointed activations across the tensor model
-            parallel ranks if `distribute-checkpointed-activations`
-            is on and either of the following conditions is met:
-              - it is not the first layer in the pipeline stage.
-                The first layer is used in the pipeline parallelism
-                and changing its shape throws an error in the backward pass.
-              - we are at the first pipeline stage, so the input tensor is
-                not used in pipeline parallelism. Note that no pipeline
-                parallelism is a special case of this.
-            """
-            not_first_layer_in_pipeline_stage = (layer_number > 0)
-            is_first_pipeline_stage = (
-                mpu.get_pipeline_model_parallel_rank() == 0)
-            return self.distribute_checkpointed_activations and \
-                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
-
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
@@ -651,7 +634,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    distribute_checkpointed_activations_helper(l),
+                    self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -662,7 +645,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        distribute_checkpointed_activations_helper(l),
+                        self.distribute_checkpointed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
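The transformer hunk is the substance of the fix: the per-layer distribute_checkpointed_activations_helper gate is removed, so every checkpointed chunk, including the first layer of a pipeline stage, now passes self.distribute_checkpointed_activations directly to mpu.checkpoint. The sketch below is a rough stand-in, not Megatron code: it shows the shape of the 'uniform' checkpointing loop after the change, using torch.utils.checkpoint.checkpoint in place of mpu.checkpoint. The real mpu.checkpoint takes the distribute flag as its second argument, which plain torch checkpointing does not, so the flag is only carried along here for illustration.

```python
# Rough stand-in for ParallelTransformer's 'uniform' checkpointing loop
# after this merge; uses torch.utils.checkpoint instead of mpu.checkpoint.
import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, chunk_size,
                         distribute_checkpointed_activations=False):
    """Run `layers` in chunks of `chunk_size`, checkpointing each chunk.

    After the merge every chunk receives the same distribute flag; previously
    a helper disabled distribution for the first layer of later pipeline stages.
    """
    def custom(start, end):
        def custom_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return custom_forward

    l = 0
    while l < len(layers):
        # Megatron would call:
        #   mpu.checkpoint(custom(l, l + chunk_size),
        #                  distribute_checkpointed_activations, ...)
        # torch.utils.checkpoint.checkpoint has no distribute argument,
        # so the flag is unused in this sketch.
        hidden_states = checkpoint(custom(l, l + chunk_size), hidden_states)
        l += chunk_size
    return hidden_states

# Example: 8 small linear layers, checkpointed two at a time.
if __name__ == '__main__':
    layers = [torch.nn.Linear(16, 16) for _ in range(8)]
    x = torch.randn(4, 16, requires_grad=True)
    out = checkpointed_forward(layers, x, chunk_size=2)
    out.sum().backward()
```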
@@ -24,7 +24,6 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
 
-from megatron import get_args
 from megatron.memory import allocate_mem_buff
 
 from .initialize import get_data_parallel_rank
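The only change shown for the mpu checkpointing module is dropping the `from megatron import get_args` import. For readers unfamiliar with the feature itself, the sketch below is a generic illustration, not the mpu implementation: it shows the basic idea behind distributing checkpointed activations, where each tensor-model-parallel rank keeps only its own contiguous slice of the saved input and all-gathers the slices again before recomputation in the backward pass. The function names and the divisibility assumption are made up for this example.

```python
# Generic illustration of distributing a checkpointed activation across
# tensor-model-parallel ranks; not the mpu implementation.
import torch
import torch.distributed as dist

def save_distributed_slice(tensor, group=None):
    """Keep only this rank's contiguous 1/world_size slice of `tensor`.

    Assumes tensor.numel() is divisible by the group's world size.
    """
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    flat = tensor.detach().reshape(-1)
    chunk = flat.numel() // world_size
    return flat[rank * chunk:(rank + 1) * chunk].clone()

def restore_from_slices(local_slice, shape, group=None):
    """All-gather the per-rank slices and rebuild the full activation."""
    world_size = dist.get_world_size(group=group)
    gathered = [torch.empty_like(local_slice) for _ in range(world_size)]
    dist.all_gather(gathered, local_slice, group=group)
    return torch.cat(gathered).view(shape)
```

Storing only a 1/world_size slice per rank is what makes the option save activation memory; the all-gather happens once per checkpointed chunk during the backward recomputation.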