Commit 4df8b7a2 authored by slym's avatar slym
Browse files

reflect feedback

parent 3f652469
...@@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): ...@@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
execution in backprop. execution in backprop.
""" """
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, use_bias): def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight) ctx.save_for_backward(input, weight)
ctx.use_bias = use_bias ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t()) output = torch.matmul(input, weight.t())
if use_bias: if bias is not None:
output = output + bias output = output + bias
return output return output
...@@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): ...@@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
grad_weight = grad_output.t().matmul(input) grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait() handle.wait()
return grad_input, grad_weight, grad_bias, None return grad_input, grad_weight, grad_bias
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
...@@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2]) input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Maxtrix multiply with asynchronouse all-reduce execution # Maxtrix multiply with asynchronouse all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply( output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias, bias is not None) input_, self.weight, bias)
output_parallel = output_parallel.view( output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1]) input_shape[0], input_shape[1], output_parallel.shape[1])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment