Commit e923ec52 authored by mshoeybi

removed contiguous buffer for checkpointed activation

parent 6a0ef5b1
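In short, as the hunks below show, this change drops the pre-allocated contiguous buffer for distributed checkpointed activations. The core of it is in CheckpointFunction.forward, where the old buffer path

    args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
    args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)

becomes a copy into a freshly allocated per-call chunk

    args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)

with the distribute_checkpointed_activations flag now passed explicitly through mpu.checkpoint instead of being inferred from whether the global buffer was allocated. Removed lines below are marked with a leading "-", added lines with "+".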
@@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()
-        # Initialize memory buffers.
-        _initialize_mem_buffs()
# Autoresume.
_init_autoresume()
@@ -224,11 +221,3 @@ def write_args_to_tensorboard():
writer.add_text(arg, str(getattr(args, arg)),
global_step=args.iteration)
-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
@@ -544,6 +544,8 @@ class ParallelTransformer(MegatronModule):
# Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.distribute_checkpointed_activations \
+            = args.distribute_checkpointed_activations
# Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -607,12 +609,11 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
+                self.distribute_checkpointed_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
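For orientation, here is a minimal, self-contained sketch of the chunked checkpointing loop above. It uses torch.utils.checkpoint.checkpoint as a stand-in for mpu.checkpoint (which, after this change, additionally takes the distribute flag as its second argument) and assumes a plain list of layers callable as layer(x, mask); the function and argument names are illustrative, not Megatron's API.

    import torch
    from torch.utils.checkpoint import checkpoint

    def checkpointed_forward(layers, hidden_states, attention_mask,
                             checkpoint_num_layers=1):
        """Run `layers` in chunks of `checkpoint_num_layers`; only each chunk's
        input is kept, and activations inside a chunk are recomputed in backward."""
        def custom(start, end):
            def custom_forward(x, mask):
                for layer in layers[start:end]:
                    x = layer(x, mask)
                return x
            return custom_forward

        l = 0
        while l < len(layers):
            hidden_states = checkpoint(custom(l, l + checkpoint_num_layers),
                                       hidden_states, attention_mask)
            l += checkpoint_num_layers
        return hidden_states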
@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
@@ -37,37 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
-# Whether to apply model parallelism to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-def init_checkpointed_activations_memory_buffer():
-    """Initialize the memory buffer for the checkpointed activations."""
-    args = get_args()
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-        args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
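For scale, the sizing arithmetic in the removed init_checkpointed_activations_memory_buffer works out as follows with some illustrative settings (the numbers below are example values, not taken from this commit):

    micro_batch_size = 4                  # example values only
    max_position_embeddings = 2048
    hidden_size = 4096
    tensor_model_parallel_size = 8
    num_layers = 32
    checkpoint_num_layers = 1

    per_layer = micro_batch_size * max_position_embeddings * hidden_size \
        // tensor_model_parallel_size                          # 4,194,304 elements
    numel = per_layer * (num_layers // checkpoint_num_layers)  # 134,217,728 elements
    # In fp16 (2 bytes per element) that is a single contiguous 256 MiB reservation
    # held for the whole run. After this change the same chunks are instead
    # allocated on demand (new_buffer=True) through the regular CUDA allocator.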
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
@@ -101,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
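A single-process sketch of the partition logic above, with rank and world size passed in explicitly rather than read from the tensor-model-parallel group (the function and parameter names here are illustrative). The point of new_buffer=True is that the default slice is a view that keeps the full activation's storage alive, while the copy gives the chunk its own storage so the full tensor can be freed:

    import torch

    def split_1d_equal_chunks(tensor, rank, world_size, new_buffer=False):
        """Return this rank's equal 1D chunk of `tensor`."""
        partition_size = tensor.numel() // world_size
        start = partition_size * rank
        end = start + partition_size
        if new_buffer:
            # Own storage: the original full tensor can now be freed.
            data = torch.empty(partition_size, dtype=tensor.dtype,
                               requires_grad=False)
            data.copy_(tensor.view(-1)[start:end])
        else:
            # View: shares storage with the full tensor.
            data = tensor.view(-1)[start:end]
        return data

    x = torch.arange(8.)
    print(split_1d_equal_chunks(x, rank=1, world_size=4))                   # tensor([2., 3.])
    print(split_1d_equal_chunks(x, rank=1, world_size=4, new_buffer=True))  # tensor([2., 3.])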
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
@@ -250,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
        ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -263,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
            ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
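The hunk above is the heart of the memory saving: before save_for_backward, the first argument's .data is swapped for just this rank's chunk (with the original shape stashed in ctx.input_0_shape), and the backward hunk below gathers and reshapes it before recomputing. Here is a toy, single-process autograd.Function that follows the same save-small / rebuild-in-backward pattern; the fp16 round-trip merely stands in for the tensor-parallel split and gather and is lossy, so this is an illustration of the pattern, not Megatron's implementation (it also skips the RNG bookkeeping the real class does):

    import torch

    class SaveSmallCheckpoint(torch.autograd.Function):

        @staticmethod
        def forward(ctx, run_function, x):
            ctx.run_function = run_function
            with torch.no_grad():
                outputs = run_function(x)
            # Shrink what gets saved and remember the original shape,
            # mirroring ctx.input_0_shape above.
            ctx.input_shape = x.data.shape
            x.data = x.data.half().view(-1)
            ctx.save_for_backward(x)
            return outputs

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            # Rebuild the full input (stand-in for gather_split_1d_tensor + view).
            x.data = x.data.float().view(ctx.input_shape)
            detached = x.detach()
            detached.requires_grad = True
            with torch.enable_grad():
                outputs = ctx.run_function(detached)
            torch.autograd.backward(outputs, grad_output)
            # One entry per forward argument; the non-tensor run_function gets None.
            return None, detached.grad

    inp = torch.randn(4, 4, requires_grad=True)
    SaveSmallCheckpoint.apply(torch.tanh, inp).sum().backward()
    print(inp.grad.shape)   # torch.Size([4, 4]); values only approximate, the fp16 stand-in is lossy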
@@ -281,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
@@ -310,7 +286,7 @@ class CheckpointFunction(torch.autograd.Function):
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads
def checkpoint(function, *args):
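Two small consequences of the new forward signature are visible above. First, backward must return one entry per forward argument, so the extra leading None in return (None, None) + grads covers the new non-tensor distribute_checkpointed_activations argument alongside run_function. Second, the checkpoint wrapper itself does not need to change: assuming it simply forwards *args to CheckpointFunction.apply, the flag passed as the second positional argument in the transformer hunk

    hidden_states = mpu.checkpoint(
        custom(l, l + self.checkpoint_num_layers),
        self.distribute_checkpointed_activations,
        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

lands directly in the distribute_checkpointed_activations parameter of forward.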