"src/vscode:/vscode.git/clone" did not exist on "d87ce2cefc6612fa95cb6d58fa3d74080d18b312"
Commit 8bac18dc authored by TiagoMAntunes's avatar TiagoMAntunes
Browse files

Updated arguments for MOELinear.apply

parent a3b2eb62
...@@ -110,21 +110,21 @@ class MOELinear(Function): ...@@ -110,21 +110,21 @@ class MOELinear(Function):
""" """
@staticmethod @staticmethod
def forward(ctx, global_input_buf, weight, fwd_expert_count): def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
(global_output_buf,) = fmoe_cuda.forward( (global_output_buf,) = fmoe_cuda.forward(
global_input_buf, weight, fwd_expert_count global_input_buf, weight, fwd_expert_count
) )
variables = (global_input_buf, weight, fwd_expert_count) variables = (global_input_buf, fwd_expert_count, weight)
ctx.save_for_backward(*variables) ctx.save_for_backward(*variables)
return global_output_buf return global_output_buf
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
(input_buf, weight, fwd_expert_count) = ctx.saved_tensors (input_buf, fwd_expert_count, weight) = ctx.saved_tensors
grad_inp_buf, grad_weight = fmoe_cuda.backward( grad_inp_buf, grad_weight = fmoe_cuda.backward(
grad_out, input_buf, weight, fwd_expert_count grad_out, input_buf, weight, fwd_expert_count
) )
return grad_inp_buf, grad_weight, None return grad_inp_buf, None, grad_weight
class MOEGather(Function): class MOEGather(Function):
......
...@@ -41,7 +41,7 @@ class FMoELinear(nn.Module): ...@@ -41,7 +41,7 @@ class FMoELinear(nn.Module):
r""" r"""
Call MOE function Call MOE function
""" """
x = MOELinear.apply(inp, self.weight, fwd_expert_count) x = MOELinear.apply(inp, fwd_expert_count, self.weight)
if self.bias is not None: if self.bias is not None:
# TODO: torch.repeat_interleave seems have numerical # TODO: torch.repeat_interleave seems have numerical
# instability in backward, leading to incorrect # instability in backward, leading to incorrect
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment