Commit 90c4bccf authored by Rich Ho

fix scatter bug across GPUs

parent 670d70ed
@@ -117,7 +117,7 @@ class MOEScatter(Function):
             )
         else:
             global_input_buf = local_input_buf
-        ctx.moe_args = inp.shape[0], world_size
+        ctx.moe_args = inp.shape[0], pos.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
         ctx.save_for_backward(*variables)
         return global_input_buf
@@ -125,19 +125,19 @@ class MOEScatter(Function):
     @staticmethod
     def backward(ctx, global_grad_in):
         (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
-        (local_batch_size, world_size) = ctx.moe_args
+        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
             (local_grad_in,) = fmoe_cuda.global_gather(
                 global_grad_in,
                 local_expert_count,
                 global_expert_count,
-                local_batch_size,
+                buf_batch_size,
                 world_size,
             )
         else:
             local_grad_in = global_grad_in
-        grad_in = _local_gather(local_grad_in, pos, local_batch_size)
+        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
         return grad_in, None, None, None, None, None
...
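Why this matters: the forward pass previously saved only inp.shape[0] and world_size, and the backward pass reused that single local_batch_size both as the buffer size for fmoe_cuda.global_gather and as the output size of _local_gather. Those two sizes agree only when every input row occupies exactly one slot in the scatter buffer; when they differ, e.g. with top-k gating scattering each row to k experts so that pos.shape[0] exceeds inp.shape[0], the multi-GPU gather ran with the wrong buffer size. The fix saves both sizes (inp_batch_size from inp.shape[0], buf_batch_size from pos.shape[0]) and uses each where it belongs.

Below is a minimal, self-contained sketch of the local gather step, not the fastmoe source; the name _local_gather_sketch and the top-k example are illustrative, assuming (as the forward scatter suggests) that pos[i] is the input row that buffer row i was copied from.

import torch

def _local_gather_sketch(grad_buf: torch.Tensor, pos: torch.Tensor,
                         out_batch_size: int) -> torch.Tensor:
    # grad_buf: [buf_batch_size, d], one gradient row per scattered slot.
    # pos:      [buf_batch_size], the input row each slot was copied from.
    # out_batch_size must be inp.shape[0], not pos.shape[0], because the
    # result is the gradient w.r.t. the original input.
    grad_in = grad_buf.new_zeros(out_batch_size, grad_buf.shape[-1])
    # Duplicate indices (top-k > 1) must accumulate, hence index_add_.
    grad_in.index_add_(0, pos, grad_buf)
    return grad_in

# Example: 2 input rows with k = 2 gating, so the buffer holds 4 rows.
inp_batch_size, d = 2, 3
pos = torch.tensor([0, 1, 0, 1])          # buf_batch_size == pos.shape[0] == 4
grad_buf = torch.ones(pos.shape[0], d)    # stand-in upstream gradients
grad_in = _local_gather_sketch(grad_buf, pos, inp_batch_size)
print(grad_in.shape)  # torch.Size([2, 3]); each entry 2.0 (k slots summed)

Conversely, the communication buffer handed to global_gather must be sized by buf_batch_size: sizing it by inp_batch_size, as the old code effectively did whenever the two differed, makes it too small to hold one row per scattered slot.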