Commit 79f16297 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix

parent 794c7e54
...@@ -53,7 +53,7 @@ std::vector<torch::Tensor> moe_backward( ...@@ -53,7 +53,7 @@ std::vector<torch::Tensor> moe_backward(
Wx+b = [W b] [x] Wx+b = [W b] [x]
[1] [1]
*/ */
return moe_cuda_forward(input, gate, weight); return moe_cuda_backward(grad_output, input, gate, weight);
} }
......
...@@ -49,3 +49,7 @@ input = torch.rand(batch_size, in_feat).cuda() ...@@ -49,3 +49,7 @@ input = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, )).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, )).int().cuda()
output = moe(input, gate) output = moe(input, gate)
y = output.mean()
y.backward()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment