Commit 414a2f86 authored by Rick Ho

remove repeat_interleave and local scatter

parent f804a121
@@ -74,6 +74,18 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
     )


+def _local_scatter(inp, pos):
+    inp_buf = torch.index_select(inp, 0, pos)
+    return inp_buf
+
+
+def _local_gather(inp, pos, out_batch_size):
+    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
+            dtype=inp.dtype, device=inp.device)
+    inp_buf.index_copy_(0, pos, inp)
+    return inp_buf
+
+
 class MOEScatter(Function):
     r"""
     Scatter input samples from [batch x sequences] to contiguous along experts.
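The two helpers added above replace the fmoe_cuda.local_scatter and fmoe_cuda.local_gather kernels with plain PyTorch indexing: _local_scatter picks rows of inp in the order given by pos, and _local_gather writes rows back into a freshly zeroed buffer of out_batch_size rows. A minimal sketch of how the pair relates, with made-up values, assuming pos is a permutation of the row indices:

import torch

inp = torch.arange(8, dtype=torch.float32).view(4, 2)   # 4 rows of width 2 (illustrative)
pos = torch.tensor([2, 0, 3, 1])                         # routing order (illustrative)

scattered = torch.index_select(inp, 0, pos)              # what _local_scatter computes
restored = torch.zeros(4, inp.shape[-1])
restored.index_copy_(0, pos, scattered)                  # what _local_gather computes
assert torch.equal(restored, inp)                        # gather undoes scatter for a permutation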
@@ -91,7 +103,7 @@ class MOEScatter(Function):
         fwd_batch_size,
         world_size,
     ):
-        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
+        local_input_buf = _local_scatter(inp, pos)
         if world_size > 1:
             (global_input_buf,) = fmoe_cuda.global_scatter(
                 local_input_buf,
@@ -122,7 +134,7 @@ class MOEScatter(Function):
             )
         else:
             local_grad_in = global_grad_in
-        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
+        grad_in = _local_gather(local_grad_in, pos, local_batch_size)
         return grad_in, None, None, None, None, None
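Unlike the old CUDA kernel, _local_gather cannot infer the output size from pos, which is why the backward pass now passes local_batch_size explicitly: the result is pre-sized with torch.zeros, and rows that receive no gradient simply stay zero. A hypothetical sketch with made-up sizes:

import torch

grad_buf = torch.ones(3, 2)              # gradients for 3 routed rows (illustrative)
pos = torch.tensor([0, 2, 3])            # local row 1 received no sample (illustrative)
local_batch_size = 4

grad_in = torch.zeros(local_batch_size, grad_buf.shape[-1])
grad_in.index_copy_(0, pos, grad_buf)    # row 1 remains zero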
@@ -175,7 +187,7 @@ class MOEGather(Function):
             )
         else:
             local_output_buf = global_output_buf
-        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
+        output = _local_gather(local_output_buf, pos, local_batch_size)
         ctx.moe_args = (global_output_buf.shape[0], world_size)
         variables = (pos, local_expert_count, global_expert_count)
@@ -186,7 +198,7 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         fwd_batch_size, world_size = ctx.moe_args
-        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
             (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                 grad_out_buf,
......
@@ -114,12 +114,19 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos,
+        inp, pos % inp.shape[0],
         local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
+
+    out_batch_size = inp.shape[0]
+    if len(gate.shape) == 2:
+        out_batch_size *= gate.shape[1]
+
     x = MOEGather.apply(
-        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
+        x, pos,
+        local_expert_count, global_expert_count,
+        out_batch_size, world_size
     )
     return x
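Two things change in this helper: the scatter now maps every routed slot back into the un-repeated input with pos % inp.shape[0] instead of materialising a repeat_interleave copy, and the gather is told to emit one output row per (sample, expert) slot, i.e. inp.shape[0] * gate.shape[1] rows when the gate carries a top-k dimension. A rough sketch of the bookkeeping, with made-up sizes and a random stand-in for the positions produced by prepare_forward:

import torch

batch_size, top_k = 4, 2
pos = torch.randperm(batch_size * top_k)                   # stand-in for the real routing positions

src_rows = pos % batch_size                                # every slot indexes an actual input row
gate = torch.zeros(batch_size, top_k, dtype=torch.long)    # stand-in for gate_top_k_idx

out_batch_size = batch_size
if len(gate.shape) == 2:
    out_batch_size *= gate.shape[1]                        # 4 * 2 = 8 expert-side rows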
@@ -216,16 +223,14 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
-        # to: (BxLxtop_k) x d_model
-        # TODO: remove repeat_interleave
-        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_general_global_forward(
-            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+            inp,
+            gate_top_k_idx,
+            self.expert_fn, self.num_expert, self.world_size
         )
-        # to: (BxL) x top_k x d_model
-        x = x.view(-1, self.top_k, self.d_model)
-        # to: (BxL) x d_model
-        gate_score = gate_score.unsqueeze(1)
+        x = x.view(inp.shape[0], self.top_k, self.d_model)
+        gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

         if self.mp_size > 1:
......
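With the repeat_interleave gone, the output of _fmoe_general_global_forward already holds top_k expert results per token; the view/bmm pair folds them into a single weighted row per token. A self-contained sketch of that reduction with made-up shapes, assuming each token's top_k outputs end up adjacent after the view:

import torch

N, top_k, d_model = 5, 2, 8                           # illustrative sizes
x = torch.randn(N * top_k, d_model)                   # one row per (token, expert) slot
gate_score = torch.softmax(torch.randn(N, top_k), dim=-1)

x = x.view(N, top_k, d_model)                         # (N, top_k, d_model)
gate_score = gate_score.view(N, 1, top_k)             # (N, 1, top_k)
out = torch.bmm(gate_score, x).reshape(-1, d_model)   # weighted sum over top_k -> (N, d_model)
assert out.shape == (N, d_model)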