Commit 9dc3c42a authored by Vijay Korthikanti

preallocating global buffer to avoid memory fragmentation

parent 8474e6e5
@@ -23,6 +23,7 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .global_vars import get_global_memory_buffer
from .initialize import initialize_megatron
def print_rank_0(message):
......
@@ -18,7 +18,8 @@
import os
import sys
import time
from functools import reduce
import operator
import torch
from megatron import dist_signal_handler
@@ -33,7 +34,7 @@ _GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
_GLOBAL_MEMORY_BUFFER = None
def get_args():
"""Return arguments."""
@@ -77,15 +78,23 @@ def get_timers():
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS
def get_signal_handler():
_ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
return _GLOBAL_SIGNAL_HANDLER
def get_global_memory_buffer():
_ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
return _GLOBAL_MEMORY_BUFFER
def _set_signal_handler():
global _GLOBAL_SIGNAL_HANDLER
_ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
_GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
@@ -98,6 +107,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
_set_global_memory_buffer()
if args.exit_signal_handler:
_set_signal_handler()
@@ -182,6 +192,12 @@ def _set_timers():
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
_ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
@@ -273,3 +289,21 @@ class Timers:
print(string, flush=True)
else:
print(string, flush=True)
class GlobalMemoryBuffer:
"Global buffer to avoid dynamic memory allocations"
def __init__(self):
self.buffer = {}
def allocate_tensor(self, tensor_shape, dtype):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get(dtype, None) is None or self.buffer[dtype].numel() < required_len:
self.buffer[dtype] = torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[dtype][0:required_len].view(*tensor_shape)
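
A note on the class above: allocate_tensor hands back a view into a single per-dtype backing tensor, so repeated requests with the same dtype reuse (and alias) one allocation instead of asking the CUDA caching allocator for a fresh block each step. A minimal sketch, assuming the class is importable from megatron.global_vars (the file this hunk touches) and a CUDA device is available:

import torch
from megatron.global_vars import GlobalMemoryBuffer  # import path assumed from the hunk above

buf = GlobalMemoryBuffer()

# The first request sizes the fp16 backing storage; a later, smaller request is
# just a view into the same storage, so its contents are scratch, not persistent.
a = buf.allocate_tensor((4, 8, 16), torch.float16)
b = buf.allocate_tensor((2, 8, 16), torch.float16)
assert a.data_ptr() == b.data_ptr()  # both views start at the buffer's base address
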
@@ -118,7 +118,7 @@ class Pooler(MegatronModule):
if self.sequence_parallel:
hidden_states = mpu.gather_from_sequence_parallel_region(
hidden_states,
to_model_parallel=False)
tensor_parallel_output_grad=False)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
......
@@ -19,7 +19,7 @@ from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_timers, get_args
from megatron import get_timers, get_args, get_global_memory_buffer
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -234,12 +234,9 @@ class CoreAttention(MegatronModule):
output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
dtype=query_layer.dtype)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
......
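
For context on the change above: the preallocated tensor is only the input operand of torch.baddbmm, and in upstream CoreAttention the call (truncated just below) passes beta=0.0, so the stale contents of the reused buffer never reach the attention scores. A rough sketch of the pattern with illustrative shapes, not the exact Megatron call:

import torch

b_np, sq, sk, hn = 8, 128, 128, 64                    # [b * np, sq, sk] plus head dim, illustrative only
device = "cuda" if torch.cuda.is_available() else "cpu"

scratch = torch.empty(b_np, sq, sk, device=device)    # stands in for the global-buffer view
q = torch.randn(b_np, sq, hn, device=device)
k = torch.randn(b_np, sk, hn, device=device)

# With beta=0.0 the first argument only supplies shape/dtype/placement; whatever
# values it held from an earlier reuse are ignored.
scores = torch.baddbmm(scratch, q, k.transpose(1, 2), beta=0.0, alpha=1.0 / (hn ** 0.5))
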
@@ -39,7 +39,7 @@ from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
from megatron import get_args, get_global_memory_buffer
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
@@ -221,9 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
torch.distributed._all_gather_base(
all_gather_buffer,
input,
@@ -248,10 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
......
@@ -217,19 +217,23 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, to_model_parallel=True):
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, to_model_parallel=True):
ctx.to_model_parallel = to_model_parallel
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
to_model_parallel = ctx.to_model_parallel
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
if to_model_parallel:
# If the computation graph after the gather operation is
# in tensor parallel mode, output gradients need to be
# reduce-scattered, whereas if the computation is duplicated,
# output gradients only need to be scattered.
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
@@ -275,8 +279,8 @@ def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
......
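
One clarification on the renamed flag: tensor_parallel_output_grad describes what backward must do with the incoming gradient, as the new comment in _GatherFromSequenceParallelRegion states. If the consumer of the gathered tensor is itself tensor-parallel (the default), each rank holds only a partial gradient and backward reduce-scatters; if the consumer is duplicated across ranks, as in the Pooler change above, each rank already holds the complete gradient and backward only splits it. A toy single-process sketch of the difference, with tp=2 and the first dimension being the gathered one:

import torch

tp = 2

# Case 1, tensor_parallel_output_grad=True: the consumer was tensor-parallel, so each
# rank holds a partial gradient; backward reduce-scatters (sum across ranks, then chunk).
partial_rank0 = torch.tensor([1., 2., 3., 4.])
partial_rank1 = torch.tensor([10., 20., 30., 40.])
summed = partial_rank0 + partial_rank1
rs_rank0, rs_rank1 = summed.chunk(tp)           # tensor([11., 22.]), tensor([33., 44.])

# Case 2, tensor_parallel_output_grad=False: the consumer was duplicated on every rank
# (as in the Pooler), so each rank already holds the identical, complete gradient;
# backward just splits its local copy. Reduce-scattering here would over-count by tp.
full_grad = torch.tensor([1., 2., 3., 4.])      # same tensor on every rank
split_rank0, split_rank1 = full_grad.chunk(tp)  # tensor([1., 2.]), tensor([3., 4.])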