"vscode:/vscode.git/clone" did not exist on "cc26cd8139c672016b6a578ea8d02138b53eb193"
Commit 02bb1f5c authored by Vijay Korthikanti

column parallel linear with sequence parallelism

parent 6658158b
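
For orientation, here is a minimal single-process sketch of the pattern this commit introduces: each tensor-parallel rank holds a sequence shard of the activation, the shards are all-gathered before the local GEMM in forward, and the input gradient is reduce-scattered back to shards in backward. The sizes (p, seq, batch, hidden, out_per_rank) are illustrative, and the torch.cat / sum / chunk calls merely stand in for the _all_gather_base and _reduce_scatter_base collectives in the diff below; none of this is part of the commit itself.

    # Single-process emulation of the sequence-parallel column-parallel linear
    # (illustrative only; the real code uses torch.distributed collectives).
    import torch

    p, seq, batch, hidden, out_per_rank = 2, 8, 4, 16, 6
    torch.manual_seed(0)

    # Each "rank" holds a sequence shard [seq/p, b, h] and a column shard of the weight.
    shards = [torch.randn(seq // p, batch, hidden) for _ in range(p)]
    weights = [torch.randn(out_per_rank, hidden) for _ in range(p)]

    # Forward on every rank: all-gather shards along the sequence dim, then a local GEMM.
    total_input = torch.cat(shards, dim=0)                   # stands in for _all_gather_base
    outputs = [total_input @ w.t() for w in weights]         # each [seq, batch, out_per_rank]

    # Backward (input gradient): local GEMM, then reduce-scatter over the sequence dim.
    grad_outputs = [torch.ones_like(o) for o in outputs]
    grad_inputs = [g @ w for g, w in zip(grad_outputs, weights)]   # each [seq, batch, hidden]
    summed = torch.stack(grad_inputs).sum(dim=0)                   # "reduce"
    sub_grad_inputs = list(summed.chunk(p, dim=0))                 # "scatter" back to shards

    print(total_input.shape, outputs[0].shape, sub_grad_inputs[0].shape)
    # torch.Size([8, 4, 16]) torch.Size([8, 4, 6]) torch.Size([4, 4, 16])
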
@@ -232,6 +232,71 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
         handle.wait()
         return grad_input, grad_weight, grad_bias
 
 
+class ColumnParallelLinearWithSequenceParallelism(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with sequence parallelism: the
+    sequence-sharded input is all-gathered in forward and the input gradient
+    is reduce-scattered in backward, with communication overlapped with the
+    matmuls.
+    """
+
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+
+        world_size = get_tensor_model_parallel_world_size()
+        dim_size = list(input.size())
+        dim_size[0] = dim_size[0] * world_size
+
+        # All-gather the sequence-sharded input along the first dimension.
+        total_input = torch.empty(dim_size, dtype=input.dtype,
+                                  device=torch.cuda.current_device(),
+                                  requires_grad=False)
+        torch.distributed._all_gather_base(
+            total_input, input, group=get_tensor_model_parallel_group())
+
+        output = torch.matmul(total_input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+
+        world_size = get_tensor_model_parallel_world_size()
+        dim_size = list(input.size())
+        dim_size[0] = dim_size[0] * world_size
+
+        # Re-gather the full input; it is needed for the weight gradient.
+        total_input = torch.empty(dim_size, dtype=input.dtype,
+                                  device=torch.cuda.current_device(),
+                                  requires_grad=False)
+        handle = torch.distributed._all_gather_base(
+            total_input, input,
+            group=get_tensor_model_parallel_group(), async_op=True)
+
+        # Delay the start of input gradient computation shortly (3us) to have
+        # gather scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_input = grad_output.matmul(weight)
+        handle.wait()
+
+        # Reduce-scatter the input gradient back to the sequence-sharded layout.
+        dim_size = list(input.size())
+        sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                     device=torch.cuda.current_device(),
+                                     requires_grad=False)
+        handle = torch.distributed._reduce_scatter_base(
+            sub_grad_input, grad_input,
+            group=get_tensor_model_parallel_group(), async_op=True)
+
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # reduce scatter scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+
+        return sub_grad_input, grad_weight, grad_bias
+
+
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
@@ -328,12 +393,18 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             # Set up backprop all-reduce.
             if self.model_parallel_memory_opt:
-                input_parallel = gather_from_sequence_parallel_region(input_)
+                # Flatten [s/p, b, h] to 2D so the fused gather + GEMM can run on it.
+                input_shape = input_.shape
+                input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
+                output_parallel = ColumnParallelLinearWithSequenceParallelism.apply(
+                    input_, self.weight, bias)
+                # Restore the 3D layout; the gather multiplies the sequence length
+                # by the tensor-model-parallel world size.
+                world_size = get_tensor_model_parallel_world_size()
+                output_parallel = output_parallel.view(
+                    input_shape[0] * world_size, input_shape[1], output_parallel.shape[1])
             else:
                 input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+                # Matrix multiply.
+                output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
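
The reshapes around the apply() call above are only layout bookkeeping between the 3D [s/p, b, h] activation and the 2D GEMM: flatten the local shard, gather p flattened shards along dim 0 inside the fused op, then view the result back as [s, b, out/p]. A small shape check with illustrative sizes (s, b, h, p, out_per_rank are assumptions, not values from the commit), again emulating the collective with torch.cat:

    import torch

    s, b, h, p, out_per_rank = 8, 4, 16, 2, 6
    input_ = torch.randn(s // p, b, h)        # this rank's sequence shard
    weight = torch.randn(out_per_rank, h)     # this rank's column shard of the weight

    input_shape = input_.shape                                            # (s/p, b, h)
    flat = input_.view(input_shape[0] * input_shape[1], input_shape[2])   # (s/p * b, h)

    # Stand-in for the all-gather inside the fused op: p flattened shards along dim 0.
    total_input = torch.cat([flat] * p, dim=0)                            # (s * b, h)
    output_parallel = total_input @ weight.t()                            # (s * b, out/p)

    # The view in forward() restores the 3D layout with the full sequence length.
    output_parallel = output_parallel.view(
        input_shape[0] * p, input_shape[1], output_parallel.shape[1])
    print(output_parallel.shape)                                          # torch.Size([8, 4, 6])
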