Commit 9dc3c42a authored by Vijay Korthikanti

preallocating global buffer to avoid memory fragmentation

parent 8474e6e5
@@ -23,6 +23,7 @@ from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .global_vars import get_global_memory_buffer
 from .initialize import initialize_megatron

 def print_rank_0(message):
...
@@ -18,7 +18,8 @@
 import os
 import sys
 import time
+from functools import reduce
+import operator

 import torch

 from megatron import dist_signal_handler
@@ -33,7 +34,7 @@ _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
+_GLOBAL_MEMORY_BUFFER = None

 def get_args():
     """Return arguments."""
@@ -77,15 +78,23 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

 def get_signal_handler():
     _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     return _GLOBAL_SIGNAL_HANDLER

+def get_global_memory_buffer():
+    _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
+    return _GLOBAL_MEMORY_BUFFER

 def _set_signal_handler():
     global _GLOBAL_SIGNAL_HANDLER
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()

 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
@@ -98,6 +107,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers()
+    _set_global_memory_buffer()

     if args.exit_signal_handler:
         _set_signal_handler()
@@ -182,6 +192,12 @@ def _set_timers():
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
     _GLOBAL_TIMERS = Timers()

+def _set_global_memory_buffer():
+    """Initialize global buffer"""
+    global _GLOBAL_MEMORY_BUFFER
+    _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
+    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
@@ -273,3 +289,21 @@ class Timers:
                 print(string, flush=True)
         else:
             print(string, flush=True)
+class GlobalMemoryBuffer:
+    "Global buffer to avoid dynamic memory allocations"
+
+    def __init__(self):
+        self.buffer = {}
+
+    def allocate_tensor(self, tensor_shape, dtype):
+        required_len = reduce(operator.mul, tensor_shape, 1)
+        if self.buffer.get(dtype, None) is None or \
+                self.buffer[dtype].numel() < required_len:
+            self.buffer[dtype] = torch.empty(required_len,
+                                             dtype=dtype,
+                                             device=torch.cuda.current_device(),
+                                             requires_grad=False)
+        return self.buffer[dtype][0:required_len].view(*tensor_shape)
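For reference (not part of the diff): a minimal usage sketch of the GlobalMemoryBuffer added above. The shapes and dtype are illustrative assumptions, a CUDA device is required because allocate_tensor places the backing storage with torch.cuda.current_device(), and the import path is inferred from the `from .global_vars import ...` lines in this commit.

import torch
from megatron.global_vars import GlobalMemoryBuffer  # module name inferred from this diff

buf = GlobalMemoryBuffer()
a = buf.allocate_tensor((4, 8), torch.half)    # first request: allocates 32 fp16 elements once
b = buf.allocate_tensor((2, 8), torch.half)    # smaller request: reuses the same storage
c = buf.allocate_tensor((16, 8), torch.half)   # larger request: the per-dtype buffer grows once
assert a.data_ptr() == b.data_ptr()            # both are views over one preallocated region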
...
@@ -118,7 +118,7 @@ class Pooler(MegatronModule):
         if self.sequence_parallel:
             hidden_states = mpu.gather_from_sequence_parallel_region(
                 hidden_states,
-                to_model_parallel=False)
+                tensor_parallel_output_grad=False)

         pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
...
@@ -19,7 +19,7 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args
+from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -234,12 +234,9 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
+        matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
+            (output_size[0]*output_size[1], output_size[2], output_size[3]),
+            dtype=query_layer.dtype)

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
...
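Not part of the commit: a self-contained sketch of the baddbmm call pattern in the CoreAttention hunk above, with made-up sizes. The tensor passed as baddbmm's first argument is ignored numerically because beta=0.0; the change in this commit is only where that tensor's memory comes from (the shared buffer instead of a fresh torch.empty on every forward pass).

import torch
from megatron.global_vars import GlobalMemoryBuffer  # module name inferred from this diff

buf = GlobalMemoryBuffer()                       # the real code calls get_global_memory_buffer()
b_np, sq, sk, hn = 8, 128, 128, 64               # hypothetical [b*np, sq, sk, hn] sizes
q = torch.randn(b_np, sq, hn, dtype=torch.half, device='cuda')
k = torch.randn(b_np, sk, hn, dtype=torch.half, device='cuda')

# Scratch tensor reused across iterations instead of being reallocated on each call.
matmul_input_buffer = buf.allocate_tensor((b_np, sq, sk), dtype=torch.half)
scores = torch.baddbmm(matmul_input_buffer,
                       q, k.transpose(1, 2),     # [b*np, sq, hn] x [b*np, hn, sk]
                       beta=0.0, alpha=1.0)      # beta=0.0: buffer contents are ignored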
@@ -39,7 +39,7 @@ from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility

-from megatron import get_args
+from megatron import get_args, get_global_memory_buffer

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
@@ -221,9 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                torch.empty(dim_size, dtype=input.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
+                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
             torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
@@ -248,10 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                torch.empty(dim_size, dtype=input.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
+                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
             handle = torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
...
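Not part of the commit: a sketch of the all-gather pattern in the hunk above. It assumes torch.distributed is already initialized with a NCCL backend and one GPU per rank, and it uses the default process group for brevity; the real code gathers over the tensor-model-parallel group.

import torch
import torch.distributed as dist
from megatron.global_vars import GlobalMemoryBuffer  # module name inferred from this diff

buf = GlobalMemoryBuffer()

def gather_full_input(local_shard):
    # All-gather the sequence-parallel shards into a tensor taken from the shared
    # pool instead of a fresh torch.empty on every forward/backward call.
    dim_size = list(local_shard.size())
    dim_size[0] *= dist.get_world_size()
    all_gather_buffer = buf.allocate_tensor(dim_size, dtype=local_shard.dtype)
    handle = dist._all_gather_base(all_gather_buffer, local_shard.contiguous(),
                                   async_op=True)
    handle.wait()   # the real code overlaps other work before waiting on the handle
    return all_gather_buffer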
@@ -217,19 +217,23 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
     """Gather the input from sequence parallel region and concatinate."""

     @staticmethod
-    def symbolic(graph, input_, to_model_parallel=True):
+    def symbolic(graph, input_, tensor_parallel_output_grad=True):
         return _gather_along_first_dim(input_)

     @staticmethod
-    def forward(ctx, input_, to_model_parallel=True):
-        ctx.to_model_parallel = to_model_parallel
+    def forward(ctx, input_, tensor_parallel_output_grad=True):
+        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
         return _gather_along_first_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        to_model_parallel = ctx.to_model_parallel
-        if to_model_parallel:
+        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
+        # If the computation graph after the gather operation is in
+        # tensor-parallel mode, the output gradients need to be
+        # reduce-scattered; if the computation is duplicated on every
+        # rank, they only need to be split (scattered).
+        if tensor_parallel_output_grad:
             return _reduce_scatter_along_first_dim(grad_output), None
         else:
             return _split_along_first_dim(grad_output), None
@@ -275,8 +279,8 @@ def scatter_to_sequence_parallel_region(input_):
     return _ScatterToSequenceParallelRegion.apply(input_)

-def gather_from_sequence_parallel_region(input_, to_model_parallel=True):
-    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
+def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
+    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)

 def reduce_scatter_to_sequence_parallel_region(input_):
...
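Not part of the commit: a single-process toy (plain tensor arithmetic, no torch.distributed) illustrating the two backward paths selected by the renamed tensor_parallel_output_grad flag, for a hypothetical world size of 2.

import torch

# Pretend these are the output gradients held by ranks 0 and 1 after the gather.
grad_rank0 = torch.arange(4.).reshape(4, 1)
grad_rank1 = torch.arange(4.).reshape(4, 1) + 10.

# tensor_parallel_output_grad=True: the graph after the gather is tensor parallel,
# so gradients are reduce-scattered (summed across ranks, then sliced per rank).
summed = grad_rank0 + grad_rank1
rs_rank0, rs_rank1 = summed.chunk(2, dim=0)

# tensor_parallel_output_grad=False: the computation after the gather is duplicated,
# so each rank just keeps its own slice, with no cross-rank reduction
# (with truly duplicated computation the two ranks' gradients would be identical).
sp_rank0 = grad_rank0.chunk(2, dim=0)[0]
sp_rank1 = grad_rank1.chunk(2, dim=0)[1]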