Commit 14c0eab4 authored by Rick Ho
Browse files

fix another bug to make global moe run

parent 0fea2991
...@@ -47,7 +47,7 @@ class MOEScatter(Function): ...@@ -47,7 +47,7 @@ class MOEScatter(Function):
(fwd_batch_size, local_batch_size, world_size) = ctx.moe_args (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
if world_size > 1: if world_size > 1:
local_grad_in, = fmoe_cuda.global_gather(global_grad_out, local_grad_in, = fmoe_cuda.global_gather(global_grad_in,
local_expert_count, global_expert_count, local_expert_count, global_expert_count,
local_batch_size, world_size) local_batch_size, world_size)
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment