Commit 4df8b7a2 authored by slym's avatar slym
Browse files

reflect feedback

parent 3f652469
......@@ -205,11 +205,11 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias, use_bias):
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = use_bias
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if use_bias:
if bias is not None:
output = output + bias
return output
......@@ -227,7 +227,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias, None
return grad_input, grad_weight, grad_bias
class ColumnParallelLinear(torch.nn.Module):
......@@ -318,7 +318,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Maxtrix multiply with asynchronouse all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias, bias is not None)
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment