Commit 136d63cb authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'dist_chkpt_act' into 'main'

Revisited distributing checkpointed activations along the tensor parallel ranks

See merge request ADLR/megatron-lm!311
parents 0be40526 cb5e611d
......@@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={},
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distrobuted checkpoint activations only supported for ' \
'nointerleaved pipeline parallelism'
_print_args(args)
return args
......
......@@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()
# Initialize memory buffers.
_initialize_mem_buffs()
# Autoresume.
_init_autoresume()
......@@ -224,11 +221,3 @@ def write_args_to_tensorboard():
writer.add_text(arg, str(getattr(args, arg)),
global_step=args.iteration)
def _initialize_mem_buffs():
"""Initialize manually allocated static memory."""
args = get_args()
# Initialize memory for checkpointed activations.
if args.distribute_checkpointed_activations:
mpu.init_checkpointed_activations_memory_buffer()
......@@ -544,6 +544,7 @@ class ParallelTransformer(MegatronModule):
# Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
......@@ -607,8 +608,22 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
def distribute_checkpointed_activations_helper(layer_number):
"""Distribute checkpointed activations across the tensor model
Parallel ranks if the `distribute-checkpointed-activations
is on and either of the following conditions is met:
- it is not the first layer in the in the pipeline stage.
The first layer is used in the pipeline parallelism
and changing its shape throws error in the backward pass.
- we are at the first pipline stage so the input tensor is
not used in pipeline parallelism. Note that no pipeline
parallelism is a special case of this.
"""
not_first_layer_in_pipeline_stage = (layer_number > 0)
is_first_pipeline_stage = (
mpu.get_pipeline_model_parallel_rank() == 0)
return self.distribute_checkpointed_activations and \
(not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
......@@ -618,6 +633,7 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
distribute_checkpointed_activations_helper(l),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
......@@ -628,6 +644,7 @@ class ParallelTransformer(MegatronModule):
if l < self.activations_checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
distribute_checkpointed_activations_helper(l),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
......
......@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
......
......@@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
# Whether apply model parallelsim to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations."""
args = get_args()
per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.tensor_model_parallel_size
num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
if args.virtual_pipeline_model_parallel_size is not None:
num_layers = num_layers // args.virtual_pipeline_model_parallel_size
if args.activations_checkpoint_method == 'uniform':
assert num_layers % args.activations_checkpoint_num_layers == 0, \
'total number of layers is not divisible by checkpoint-chunk_size'
num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
elif args.activations_checkpoint_method == 'block':
assert args.activations_checkpoint_num_layers <= num_layers, \
'total number of layers is fewer than the number of layers to checkpoint'
num_checkpointer_layers = args.activations_checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not args.fp16:
dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
'checkpointed activations memory buffer is already allocated.'
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
'checkpointed activations', numel, dtype, track_usage=False)
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
......@@ -110,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor):
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
......@@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, *args):
def forward(ctx, run_function, distribute_checkpointed_activations, *args):
ctx.run_function = run_function
ctx.distribute_checkpointed_activations \
= distribute_checkpointed_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
if distribute_checkpointed_activations:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
args[0].data)
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
new_buffer=True)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
......@@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
......@@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
return (None, None) + grads
def checkpoint(function, *args):
def checkpoint(function, distribute_checkpointed_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment