Commit 12518332 authored by Jared Casper

Merge branch 'dist_act_chkpt' into 'main'

added splitting of checkpointed activations across model parallel partitions

See merge request ADLR/megatron-lm!121
parents 930ec4a2 5d29769c
...@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
_print_args(args)
return args
...@@ -200,6 +205,10 @@ def _add_training_args(parser):
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--train-iters', type=int, default=None,
...
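The new check in parse_args ties the two flags together: --distribute-checkpointed-activations only makes sense when --checkpoint-activations is also set. A minimal stand-alone sketch of that dependency, using a hypothetical bare argparse parser rather than the real Megatron one:

```python
import argparse

# Hypothetical minimal parser; the real flags live in megatron/arguments.py.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-activations', action='store_true')
parser.add_argument('--distribute-checkpointed-activations', action='store_true')

args = parser.parse_args(['--checkpoint-activations',
                          '--distribute-checkpointed-activations'])

# Same dependency the merge request enforces: distributing checkpointed
# activations requires activation checkpointing to be enabled at all.
if args.distribute_checkpointed_activations:
    assert args.checkpoint_activations, \
        'for distribute-checkpointed-activations to work you ' \
        'need to enable checkpoint-activations'
```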
...@@ -72,6 +72,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
else:
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()
# Initialize memory buffers.
_initialize_mem_buffs()
# Autoresume.
_init_autoresume()
...@@ -151,3 +154,12 @@ def _write_args_to_tensorboard():
if writer:
for arg in vars(args):
writer.add_text(arg, str(getattr(args, arg)))
def _initialize_mem_buffs():
"""Initialize manually allocated static memory."""
args = get_args()
# Initialize memory for checkpointed activations.
if args.distribute_checkpointed_activations:
mpu.init_checkpointed_activations_memory_buffer()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()
def allocate_mem_buff(name, numel, dtype, track_usage):
"""Allocate a memory buffer."""
assert name not in _MEM_BUFFS, \
'memory buffer {} already allocated.'.format(name)
_MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
return _MEM_BUFFS[name]
def get_mem_buff(name):
"""Get the memory buffer."""
return _MEM_BUFFS[name]
class MemoryBuffer:
"""Contiguous memory buffer.
Allocate a contiguous memory of type `dtype` and size `numel`. It is
used to reduce memory fragmentation.
Usage: After the allocation, the `_start` index is set to the first
index of the memory. A memory chunk starting from `_start` index
can be `allocated` for an input tensor, with the elements of the
tensor being copied. The buffer can be reused by resetting the
`_start` index.
"""
def __init__(self, name, numel, dtype, track_usage):
if torch.distributed.get_rank() == 0:
element_size = torch.tensor([], dtype=dtype).element_size()
print('> building the {} memory buffer with {} num elements '
'and {} dtype ({:.1f} MB)...'.format(
name, numel, dtype, numel*element_size/1024/1024),
flush=True)
self.name = name
self.numel = numel
self.dtype = dtype
self.data = torch.empty(self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# Index tracking the start of the free memory.
self._start = 0
# Values used for tracking usage.
self.track_usage = track_usage
if self.track_usage:
self.in_use_value = 0.0
self.total_value = 0.0
def reset(self):
"""Reset the buffer start index to the beginning of the buffer."""
self._start = 0
def is_in_use(self):
"""Whether the current buffer hold on to any memory."""
return self._start > 0
def numel_in_use(self):
"""Return number of elements in use."""
return self._start
def add(self, tensor):
"""Allocate a chunk of memory from the buffer to tensor and copy
the values."""
assert tensor.dtype == self.dtype, \
'Input tensor type {} different from buffer type {}'.format(
tensor.dtype, self.dtype)
# Number of elements of the input tensor.
tensor_numel = torch.numel(tensor)
new_start = self._start + tensor_numel
assert new_start <= self.numel, \
'Not enough memory left in the buffer ({} > {})'.format(
tensor_numel, self.numel - self._start)
# New tensor is a view into the memory.
new_tensor = self.data[self._start:new_start]
self._start = new_start
new_tensor = new_tensor.view(tensor.shape)
new_tensor.copy_(tensor)
# Return a pointer to the new tensor.
return new_tensor
def get_data(self):
"""Return the data currently in use."""
if self.track_usage:
self.in_use_value += float(self._start)
self.total_value += float(self.numel)
return self.data[:self._start]
def print_average_usage(self):
"""Print memory usage average over time. We would like this value
to be as high as possible."""
assert self.track_usage, 'You need to enable track usage.'
if torch.distributed.get_rank() == 0:
print(' > usage of {} memory buffer: {:.2f} %'.format(
self.name, self.in_use_value * 100.0 / self.total_value),
flush=True)
class RingMemBuffer:
"""A ring of memory buffers."""
def __init__(self, name, num_buffers, numel, dtype, track_usage):
self.num_buffers = num_buffers
self.buffers = [
allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage)
for i in range(num_buffers)]
self._index = -1
def get_next_buffer(self):
self._index += 1
self._index = self._index % self.num_buffers
buff = self.buffers[self._index]
assert not buff.is_in_use(), 'buffer is already in use.'
return buff
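A short usage sketch of the buffer API defined above, with illustrative sizes. It assumes torch.distributed is already initialized (the constructor prints from rank 0) and that a CUDA device is available:

```python
import torch
from megatron.memory import allocate_mem_buff

# Illustrative size: 1024 half-precision elements.
buff = allocate_mem_buff('demo', numel=1024, dtype=torch.half,
                         track_usage=True)

# Copy a tensor into the buffer; the result is a view into the buffer's memory.
x = torch.randn(4, 64, dtype=torch.half, device=torch.cuda.current_device())
x_view = buff.add(x)
assert x_view.shape == x.shape
assert buff.numel_in_use() == 4 * 64

# Reset to reuse the same memory on the next iteration.
buff.reset()
assert not buff.is_in_use()
```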
...@@ -411,6 +411,8 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
...
...@@ -45,7 +45,9 @@ from .mappings import scatter_to_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .utils import divide
from .utils import split_tensor_along_last_dim
...@@ -24,14 +24,50 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
# Whether to apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations."""
args = get_args()
per_layer = args.batch_size * args.max_position_embeddings * \
args.hidden_size // args.model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not args.fp16:
dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
'checkpointed activations memory buffer is already allocated.'
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
'checkpointed activations', numel, dtype, track_usage=False)
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
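To make the sizing above concrete, a worked example with illustrative values (not taken from any particular config): with batch_size=4, max_position_embeddings=1024, hidden_size=1024, model_parallel_size=2, num_layers=24, and checkpoint_num_layers=1 in fp16, each rank reserves room for its 1/model_parallel_size slice of every checkpointed hidden state:

```python
batch_size, seq_len, hidden, mp_size = 4, 1024, 1024, 2
num_layers, checkpoint_num_layers = 24, 1

per_layer = batch_size * seq_len * hidden // mp_size   # 2,097,152 elements
num_chunks = num_layers // checkpoint_num_layers        # 24 checkpointed chunks
numel = per_layer * num_chunks                           # 50,331,648 elements
print(numel * 2 / 2**20, 'MiB in fp16 per rank')         # 96.0 MiB
```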
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
...@@ -65,6 +101,29 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_model_parallel_world_size()
start_index = partition_size * get_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor,
group=get_model_parallel_group())
return gathered
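A CPU-only sketch of how split_tensor_into_1d_equal_chunks and gather_split_1d_tensor fit together, with the world size and ranks simulated locally instead of going through torch.distributed:

```python
import torch

world_size = 4                              # simulated model parallel world size
full = torch.arange(16.0)                   # the flattened activation on every rank
partition = full.numel() // world_size      # elements kept per rank

# split_tensor_into_1d_equal_chunks: each rank keeps its contiguous 1D slice.
chunks = [full[r * partition:(r + 1) * partition] for r in range(world_size)]

# gather_split_1d_tensor: an all_gather reassembles the full flat tensor.
gathered = torch.cat(chunks)
assert torch.equal(gathered, full)
```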
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
...@@ -199,9 +258,21 @@ class CheckpointFunction(torch.autograd.Function):
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
args[0].data)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
...@@ -210,6 +281,9 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
...
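Putting the two hunks together: before save_for_backward, the first input's .data is swapped for the rank-local 1D chunk held in the static buffer, and in backward the chunks are gathered and reshaped to the saved shape before recomputation. A CPU-only sketch of that pattern, with a two-way split simulated and torch.cat standing in for the all-gather:

```python
import torch

hidden = torch.randn(2, 3, 4)
input_0_shape = hidden.data.shape            # what forward stashes on ctx

# Forward: keep only this (simulated) rank's half of the flattened tensor.
flat = hidden.data.view(-1)
my_chunk, other_chunk = flat[:12].clone(), flat[12:].clone()
hidden.data = my_chunk                       # the tensor saved for backward

# Backward: gather every rank's chunk and restore the original shape.
hidden.data = torch.cat([my_chunk, other_chunk]).view(input_0_shape)
assert hidden.data.shape == input_0_shape
```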