Commit 952e3135 authored by Rick Ho's avatar Rick Ho
Browse files

fix scatter/gather bug to make it correct

parent 2d250fbf
......@@ -14,7 +14,7 @@ class MOELocal(Function):
# expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
ecc = expert_count.cpu()
input_buf, = fmoe_cuda.local_scatter(inp, pos)
input_buf, = fmoe_cuda.local_gather(inp, pos)
output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
output = fmoe_cuda.local_gather(output_buf, pos)
......@@ -52,13 +52,12 @@ class MOEGlobal(Function):
global_expert_count, = fmoe_cuda.expert_exchange(
local_expert_count, num_expert, world_size)
print('Local {} Global {}'.format(local_expert_count, global_expert_count))
fwd_expert_count = global_expert_count.view(num_expert,
world_size).sum(dim=1).cpu()
fwd_expert_count = global_expert_count.view(world_size,
num_expert).sum(dim=0).cpu()
fwd_batch_size = int(fwd_expert_count.sum().item())
local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
local_input_buf, = fmoe_cuda.local_gather(inp, pos)
local_expert_count = local_expert_count.cpu()
global_expert_count = global_expert_count.cpu()
......@@ -67,7 +66,7 @@ class MOEGlobal(Function):
local_expert_count, global_expert_count,
fwd_batch_size, inp.shape[0], world_size)
output, = fmoe_cuda.local_gather(local_output_buf, pos)
output, = fmoe_cuda.local_scatter(local_output_buf, pos)
variables = (global_input_buf, gate, weight,
local_expert_count, global_expert_count, fwd_expert_count,
......
......@@ -135,9 +135,9 @@ def test():
print('Rank {} {} abs err {}'.format(rank, name, err))
if err > 1e-3:
sys.stderr.write('=========== moe out ==============\n')
sys.stderr.write('{}'.format(mo))
sys.stderr.write('{}\n'.format(mo))
sys.stderr.write('=========== raw out ==============\n')
sys.stderr.write('{}'.format(ro))
sys.stderr.write('{}\n'.format(ro))
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment