Commit 8bac18dc authored by TiagoMAntunes
Browse files

Updated arguments for MOELinear.apply

parent a3b2eb62
......@@ -110,21 +110,21 @@ class MOELinear(Function):
"""
@staticmethod
def forward(ctx, global_input_buf, weight, fwd_expert_count):
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
(global_output_buf,) = fmoe_cuda.forward(
global_input_buf, weight, fwd_expert_count
)
variables = (global_input_buf, weight, fwd_expert_count)
variables = (global_input_buf, fwd_expert_count, weight)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, weight, fwd_expert_count) = ctx.saved_tensors
(input_buf, fwd_expert_count, weight) = ctx.saved_tensors
grad_inp_buf, grad_weight = fmoe_cuda.backward(
grad_out, input_buf, weight, fwd_expert_count
)
return grad_inp_buf, grad_weight, None
return grad_inp_buf, None, grad_weight
class MOEGather(Function):
......
......@@ -41,7 +41,7 @@ class FMoELinear(nn.Module):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, self.weight, fwd_expert_count)
x = MOELinear.apply(inp, fwd_expert_count, self.weight)
if self.bias is not None:
# TODO: torch.repeat_interleave seems to have numerical
# instability in backward, leading to incorrect
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment