Commit 0d77c0e9 authored by Vijay Korthikanti

refactor to help merge with main

parent 02bb1f5c
@@ -97,8 +97,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.sequence_parallel = sequence_parallel
 
         # set sequence parallelism flag on weight and bias parameters
-        self.weight.sequence_parallel = self.sequence_parallel
-        self.bias.sequence_parallel = self.sequence_parallel
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
 
     def reset_parameters(self):
...
@@ -26,21 +26,31 @@ from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
 
 
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
     args = get_args()
     # Parallel logits.
-    if not args.model_parallel_memory_opt:
-        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    if args.async_tensor_model_parallel_allreduce or \
+       args.model_parallel_memory_opt:
+        input_parallel = input_
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+            model_parallel
+        model_parallel_memory_opt = args.model_parallel_memory_opt and \
+            model_parallel
     else:
-        input_parallel = input_
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+        model_parallel_memory_opt = False
+
     # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
-    else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
+        input_parallel, word_embeddings_weight, bias,
+        args.gradient_accumulation_fusion,
+        async_grad_allreduce, model_parallel_memory_opt)
+
     # Gather if needed.
     if parallel_output:
         return logits_parallel
...
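The branch above only falls back to mpu.copy_to_tensor_model_parallel_region when the fused linear path is not going to reduce the input gradient itself. A rough sketch of what that primitive does (identity in the forward pass, all-reduce of the incoming gradient in the backward pass); the class name and the single-process "gloo" group below are illustrative only, not the library's code:

# Sketch only: mimics the semantics of mpu.copy_to_tensor_model_parallel_region
# (identity forward, all-reduce of the gradient in backward). A single-process
# "gloo" group makes the all-reduce a no-op while still exercising the code path.
import os
import torch
import torch.distributed as dist


class CopyToTensorParallelRegionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        # Forward is the identity: every tensor-parallel rank sees the same input.
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank computed with a replicated input, so input gradients are summed.
        dist.all_reduce(grad_output)
        return grad_output


os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, 8, requires_grad=True)
y = CopyToTensorParallelRegionSketch.apply(x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])

dist.destroy_process_group()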
@@ -202,56 +202,34 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output
 
 
-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
+    Linear layer execution with asynchronous communication and gradient accumulation
+    fusion in backprop.
     """
 
     @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce, model_parallel_memory_opt):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
-        output = torch.matmul(input, weight.t())
-        if bias is not None:
-            output = output + bias
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        use_bias = ctx.use_bias
-        grad_input = grad_output.matmul(weight)
-        # Asyncronous all-reduce
-        handle = torch.distributed.all_reduce(
-            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
-
-
-class ColumnParallelLinearWithSequenceParallelism(torch.autograd.Function):
-    """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
-    """
-
-    @staticmethod
-    def forward(ctx, input, weight, bias):
-        ctx.save_for_backward(input, weight)
-        ctx.use_bias = bias is not None
-        world_size = get_tensor_model_parallel_world_size()
-        dim_size = list(input.size())
-        dim_size[0] = dim_size[0] * world_size
-        total_input = torch.empty(dim_size, dtype=input.dtype,
-                                  device=torch.cuda.current_device(),
-                                  requires_grad=False)
-        torch.distributed._all_gather_base(total_input, input,
-                                           group=get_tensor_model_parallel_group())
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.model_parallel_memory_opt = model_parallel_memory_opt
+
+        if model_parallel_memory_opt:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            total_input = torch.empty(dim_size, dtype=input.dtype,
+                                      device=torch.cuda.current_device(),
+                                      requires_grad=False)
+            torch.distributed._all_gather_base(total_input, input,
+                                               group=get_tensor_model_parallel_group())
+        else:
+            total_input = input
+
         output = torch.matmul(total_input, weight.t())
         if bias is not None:
             output = output + bias
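In the memory-optimized path above, each rank holds a [sequence/world_size, batch, hidden] slice of the activations and _all_gather_base rebuilds the full tensor before the GEMM, which is why dim_size[0] is scaled by world_size. A local, single-process illustration of that shape bookkeeping, with torch.cat standing in for the all-gather (the shapes and the per-rank chunks are made up for the example):

# Local illustration of the dim_size[0] * world_size bookkeeping used above;
# torch.cat over per-rank chunks plays the role of _all_gather_base.
import torch

world_size = 4
seq_len, batch, hidden = 16, 2, 8

full_input = torch.randn(seq_len, batch, hidden)
# Each tensor-parallel rank owns a contiguous slice along the sequence dim.
rank_inputs = list(torch.chunk(full_input, world_size, dim=0))

local_input = rank_inputs[0]                    # [s/p, b, h] = [4, 2, 8]
dim_size = list(local_input.size())
dim_size[0] = dim_size[0] * world_size          # -> [16, 2, 8]

total_input = torch.empty(dim_size, dtype=local_input.dtype)
torch.cat(rank_inputs, dim=0, out=total_input)  # stand-in for the all-gather

assert total_input.shape == full_input.shape
assert torch.equal(total_input, full_input)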
@@ -261,41 +239,72 @@ class ColumnParallelLinearWithSequenceParallelism(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
-        world_size = get_tensor_model_parallel_world_size()
-        dim_size = list(input.size())
-        dim_size[0] = dim_size[0] * world_size
-        total_input = torch.empty(dim_size, dtype=input.dtype,
-                                  device=torch.cuda.current_device(),
-                                  requires_grad=False)
-        handle = torch.distributed._all_gather_base(total_input, input,
-            group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of intput gradient computation shortly (3us) to have
-        # gather scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.model_parallel_memory_opt:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            total_input = torch.empty(dim_size, dtype=input.dtype,
+                                      device=torch.cuda.current_device(),
+                                      requires_grad=False)
+            handle = torch.distributed._all_gather_base(total_input, input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of input gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+        else:
+            total_input = input
+
         grad_input = grad_output.matmul(weight)
-        handle.wait()
 
-        dim_size = list(input.size())
-        sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                     device=torch.cuda.current_device(),
-                                     requires_grad=False)
-        # reduce_scatter
-        handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
-            group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # reduce scatter scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(total_input)
+        if ctx.model_parallel_memory_opt:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                       grad_output.shape[2])
+        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                       total_input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.model_parallel_memory_opt:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+                                                            group=get_tensor_model_parallel_group(),
+                                                            async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
 
-        return sub_grad_input, grad_weight, grad_bias
+        if ctx.async_grad_allreduce:
+            handle.wait()
+            return grad_input, grad_weight, grad_bias, None, None, None
+
+        if ctx.model_parallel_memory_opt:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
+
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 class ColumnParallelLinear(torch.nn.Module):
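The "Convert the tensor shapes to 2D" step exists because the weight-gradient GEMM, and the fused wgrad_gemm_accum_fp32 kernel, expect 2D operands; folding the sequence and batch dimensions together leaves grad_weight unchanged. A quick standalone check of that equivalence (shapes chosen arbitrarily for the example):

# Folding [s, b, h] activations/gradients to [s*b, h] leaves the weight
# gradient grad_output^T @ total_input unchanged.
import torch

s, b, h_in, h_out = 16, 2, 8, 12
total_input = torch.randn(s, b, h_in)
grad_output = torch.randn(s, b, h_out)

# 3D version: contract over both the sequence and batch dimensions.
grad_weight_3d = torch.einsum('sbo,sbi->oi', grad_output, total_input)

# 2D version used in the backward pass above.
go_2d = grad_output.view(s * b, h_out)
ti_2d = total_input.view(s * b, h_in)
grad_weight_2d = go_2d.t().matmul(ti_2d)

assert torch.allclose(grad_weight_3d, grad_weight_2d, atol=1e-5)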
@@ -375,37 +384,25 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             not args.no_async_tensor_model_parallel_allreduce and
             world_size > 1)
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.model_parallel_memory_opt = (
+            args.model_parallel_memory_opt and
+            world_size > 1)
+        assert not self.async_tensor_model_parallel_allreduce or \
+            not self.model_parallel_memory_opt
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
 
-        if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
-            # Maxtrix multiply with asynchronouse all-reduce execution
-            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
-                input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                input_shape[0], input_shape[1], output_parallel.shape[1])
+        if self.async_tensor_model_parallel_allreduce or \
+           self.model_parallel_memory_opt:
+            input_parallel = input_
         else:
-            # Set up backprop all-reduce.
-            if self.model_parallel_memory_opt:
-                input_shape = input_.shape
-                input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
-                output_parallel = ColumnParallelLinearWithSequenceParallelism.apply(
-                    input_, self.weight, bias)
-                world_size = get_tensor_model_parallel_world_size()
-                output_parallel = output_parallel.view(
-                    input_shape[0] * world_size, input_shape[1], output_parallel.shape[1])
-            else:
-                input_parallel = copy_to_tensor_model_parallel_region(input_)
-                # Matrix multiply.
-                output_parallel = F.linear(input_parallel, self.weight, bias)
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+        # Matrix multiply.
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
+            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
+            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
+
         if self.gather_output:
             # All-gather across the partitions.
             assert not self.model_parallel_memory_opt
...
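LinearWithGradAccumulationAndAsyncCommunication.apply is now called with three boolean flags after the tensors; a custom torch.autograd.Function has to return one gradient per forward argument, with None for the non-tensor ones. A stripped-down sketch of that calling contract (no communication, illustrative names only):

# Toy version of the call pattern: tensors plus configuration flags go through
# apply(), and backward() returns None for every non-tensor argument.
import torch


class LinearWithFlagsSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, some_flag, another_flag, third_flag):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        # One entry per forward argument: the three flags get None.
        return grad_input, grad_weight, grad_bias, None, None, None


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(6, 8, requires_grad=True)
b = torch.randn(6, requires_grad=True)
out = LinearWithFlagsSketch.apply(x, w, b, True, False, False)
out.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)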