Commit 79f16297 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix

parent 794c7e54
......@@ -53,7 +53,7 @@ std::vector<torch::Tensor> moe_backward(
Wx+b = [W b] [x]
[1]
*/
return moe_cuda_forward(input, gate, weight);
return moe_cuda_backward(grad_output, input, gate, weight);
}
......
......@@ -49,3 +49,7 @@ input = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, )).int().cuda()
output = moe(input, gate)
y = output.mean()
y.backward()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment