Commit 739cb43d authored by Vijay Korthikanti

resolved review comments

parent 9dc3c42a
@@ -304,7 +304,13 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.recompute_method is None, \
             'recompute method is not yet supported for ' \
             'selective recomputing granularity'

+    # disable sequence parallelism when tp=1
+    # to avoid a change in numerics when
+    # sequence parallelism is enabled.
+    if args.tensor_model_parallel_size == 1:
+        args.sequence_parallel = False
+
     # disable async_tensor_model_parallel_allreduce when
     # model parallel memory optimization is enabled
     if args.sequence_parallel:
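The new guard runs before the existing async all-reduce check, so the two flags stay consistent. A minimal sketch of the combined logic, with `SimpleNamespace` standing in for Megatron's parsed argument namespace and the helper name `resolve_parallel_flags` purely illustrative (the body of the second check is assumed to simply clear the flag, as its comment suggests):

```python
from types import SimpleNamespace

def resolve_parallel_flags(args):
    # Sequence parallelism is forced off at tp=1 so numerics match the
    # non-sequence-parallel path (per the comment in the hunk above).
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False
    # The pre-existing check then disables the async all-reduce overlap
    # whenever sequence parallelism remains enabled.
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False
    return args

args = resolve_parallel_flags(SimpleNamespace(
    tensor_model_parallel_size=1,
    sequence_parallel=True,
    async_tensor_model_parallel_allreduce=True))
assert not args.sequence_parallel                  # forced off at tp=1
assert args.async_tensor_model_parallel_allreduce  # overlap stays available
```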
@@ -292,18 +292,21 @@ class Timers:

 class GlobalMemoryBuffer:
-    "Global buffer to avoid dynamic memory allocations"
+    """Global buffer to avoid dynamic memory allocations.
+    Caller should ensure that buffers of the same name
+    are not used concurrently."""

     def __init__(self):
         self.buffer = {}

-    def allocate_tensor(self, tensor_shape, dtype):
+    def get_tensor(self, tensor_shape, dtype, name):
         required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get(dtype, None) is None or self.buffer[dtype].numel() < required_len:
-            self.buffer[dtype] = torch.empty(required_len,
-                                             dtype=dtype,
-                                             device=torch.cuda.current_device(),
-                                             requires_grad=False)
+        if self.buffer.get((name, dtype), None) is None or \
+                self.buffer[(name, dtype)].numel() < required_len:
+            self.buffer[(name, dtype)] = \
+                torch.empty(required_len,
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)

-        return self.buffer[dtype][0:required_len].view(*tensor_shape)
+        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
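The rename from `allocate_tensor` to `get_tensor` and the `(name, dtype)` key mean two call sites with the same dtype no longer share (and silently clobber) one buffer. A minimal CPU sketch of the keying behavior; `NamedBuffer` is an illustrative stand-in, not the real class, which allocates on `torch.cuda.current_device()`:

```python
import operator
from functools import reduce
import torch

class NamedBuffer:
    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
        key = (name, dtype)
        # Grow-only cache: reallocate just when the cached tensor is too small.
        if key not in self.buffer or self.buffer[key].numel() < required_len:
            self.buffer[key] = torch.empty(required_len, dtype=dtype)
        return self.buffer[key][0:required_len].view(*tensor_shape)

buf = NamedBuffer()
a = buf.get_tensor((2, 3), torch.float32, "attn")
b = buf.get_tensor((4, 4), torch.float32, "mlp")
# Distinct names get distinct storage, so `b` cannot clobber `a`; asking
# for "attn" again with a smaller shape reuses the same underlying storage.
assert a.data_ptr() != b.data_ptr()
assert buf.get_tensor((1, 2), torch.float32, "attn").data_ptr() == a.data_ptr()
```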
@@ -234,9 +234,9 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocating input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().allocate_tensor(
-            (output_size[0]*output_size[1], output_size[2], output_size[3]),
-            dtype=query_layer.dtype)
+        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+            (output_size[0]*output_size[1], output_size[2], output_size[3]),
+            query_layer.dtype, "mpu")

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
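For context on this call site: `torch.baddbmm` needs an `input` tensor of the output shape, and with `beta=0.0` its values are ignored, so handing it a reused global buffer avoids allocating a fresh `[b * np, sq, sk]` tensor on every forward pass. A toy CPU sketch of the pattern, with shapes invented for illustration:

```python
import torch

b_np, sq, sk, hn = 4, 8, 8, 16       # batch*heads, seq lengths, head dim
query = torch.randn(b_np, sq, hn)    # [b * np, sq, hn]
key_t = torch.randn(b_np, hn, sk)    # [b * np, hn, sk]
norm_factor = hn ** 0.5

# Stand-in for get_tensor(...): a reused scratch tensor of the output shape.
matmul_input_buffer = torch.empty(b_np, sq, sk)

# beta=0.0 means the buffer's arbitrary contents are ignored; only
# alpha * (query @ key_t) is produced.
matmul_result = torch.baddbmm(matmul_input_buffer, query, key_t,
                              beta=0.0, alpha=1.0 / norm_factor)
assert torch.allclose(matmul_result, torch.bmm(query, key_t) / norm_factor)
```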
@@ -221,7 +221,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
             torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
@@ -246,7 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size[0] = dim_size[0] * world_size

             all_gather_buffer = \
-                get_global_memory_buffer().allocate_tensor(dim_size, dtype=input.dtype)
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
             handle = torch.distributed._all_gather_base(
                 all_gather_buffer,
                 input,
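Both hunks gather into a buffer fetched under the shared "mpu" name rather than a freshly allocated tensor. A single-process CPU sketch of the buffered gather; it uses the portable `all_gather` over per-rank views of one flat buffer as a stand-in for the private `_all_gather_base` (renamed `all_gather_into_tensor` in newer PyTorch), which on NCCL gathers directly into the flat buffer:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

inp = torch.arange(6, dtype=torch.float32).view(3, 2)
world_size = dist.get_world_size()

# Output first dim grows by world_size; the gather lands in a reused,
# preallocated buffer (stand-in for get_tensor(dim_size, input.dtype, "mpu")).
dim_size = list(inp.size())
dim_size[0] *= world_size
all_gather_buffer = torch.empty(dim_size, dtype=inp.dtype)

# Gather into per-rank views of the single flat buffer.
dist.all_gather(list(all_gather_buffer.chunk(world_size)), inp.contiguous())
assert torch.equal(all_gather_buffer, inp)  # world_size == 1: identity
dist.destroy_process_group()
```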