Commit 414a2f86 authored by Rick Ho

remove repeat_interleave and local scatter

parent f804a121
@@ -74,6 +74,18 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
     )


+def _local_scatter(inp, pos):
+    inp_buf = torch.index_select(inp, 0, pos)
+    return inp_buf
+
+
+def _local_gather(inp, pos, out_batch_size):
+    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
+            dtype=inp.dtype, device=inp.device)
+    inp_buf.index_copy_(0, pos, inp)
+    return inp_buf
+
+
 class MOEScatter(Function):
     r"""
     Scatter input samples from [batch x sequences] to contiguous along experts.
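Note on the new helpers: they replace the `fmoe_cuda.local_scatter` / `fmoe_cuda.local_gather` kernels with plain PyTorch indexing, so the row reordering runs on any device via `index_select` / `index_copy_`. A minimal sketch of the round trip (the test harness is illustrative, not part of the commit):

```python
import torch

def _local_scatter(inp, pos):
    # Reorder rows of inp into expert-sorted order.
    return torch.index_select(inp, 0, pos)

def _local_gather(inp, pos, out_batch_size):
    # Write rows back to their original slots. index_copy_ overwrites rather
    # than accumulates, so this assumes pos is a permutation (each destination
    # row written exactly once); overlapping positions would need index_add_.
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
                          dtype=inp.dtype, device=inp.device)
    inp_buf.index_copy_(0, pos, inp)
    return inp_buf

x = torch.randn(6, 4)
pos = torch.randperm(6)  # stand-in for the expert-sorted position index
assert torch.equal(_local_gather(_local_scatter(x, pos), pos, 6), x)
```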
@@ -91,7 +103,7 @@ class MOEScatter(Function):
         fwd_batch_size,
         world_size,
     ):
-        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
+        local_input_buf = _local_scatter(inp, pos)
         if world_size > 1:
             (global_input_buf,) = fmoe_cuda.global_scatter(
                 local_input_buf,
@@ -122,7 +134,7 @@ class MOEScatter(Function):
             )
         else:
             local_grad_in = global_grad_in
-        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
+        grad_in = _local_gather(local_grad_in, pos, local_batch_size)
         return grad_in, None, None, None, None, None
@@ -175,7 +187,7 @@ class MOEGather(Function):
             )
         else:
             local_output_buf = global_output_buf
-        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
+        output = _local_gather(local_output_buf, pos, local_batch_size)
         ctx.moe_args = (global_output_buf.shape[0], world_size)
         variables = (pos, local_expert_count, global_expert_count)
@@ -186,7 +198,7 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         fwd_batch_size, world_size = ctx.moe_args
-        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
             (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                 grad_out_buf,
...
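With `world_size == 1` both autograd functions skip the `fmoe_cuda.global_scatter` all-to-all and reduce to the `_local_*` pair, so the scatter's forward/backward duality can be checked in isolation. A simplified single-process analogue (`LocalScatter` is a hypothetical class written only to mirror the pattern above):

```python
import torch
from torch.autograd import Function

class LocalScatter(Function):
    """Single-process sketch of MOEScatter: forward reorders rows with
    index_select, backward routes gradients back with index_copy_."""

    @staticmethod
    def forward(ctx, inp, pos):
        ctx.save_for_backward(pos)
        ctx.local_batch_size = inp.shape[0]
        return torch.index_select(inp, 0, pos)

    @staticmethod
    def backward(ctx, grad_out):
        (pos,) = ctx.saved_tensors
        grad_in = torch.zeros(ctx.local_batch_size, grad_out.shape[-1],
                              dtype=grad_out.dtype, device=grad_out.device)
        grad_in.index_copy_(0, pos, grad_out.contiguous())
        return grad_in, None

x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
pos = torch.randperm(5)
torch.autograd.gradcheck(lambda t: LocalScatter.apply(t, pos), (x,))
```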
@@ -114,12 +114,19 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos,
+        inp, pos % inp.shape[0],
         local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
+
+    out_batch_size = inp.shape[0]
+    if len(gate.shape) == 2:
+        out_batch_size *= gate.shape[1]
+
     x = MOEGather.apply(
-        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
+        x, pos,
+        local_expert_count, global_expert_count,
+        out_batch_size, world_size
     )
     return x
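The `pos % inp.shape[0]` remap is what lets the `repeat_interleave` in `FMoE.forward` go away: instead of materializing a (batch × top_k)-row copy of the input and indexing into it, the scatter folds the duplication into the index itself, and `MOEGather` allocates the expanded `out_batch_size` (batch × top_k when the gate is 2-D) on the way back. A sketch of the identity this relies on, assuming expanded row `j` maps to original row `j % batch` (a tiled layout; an interleaved layout would use `pos // top_k` instead):

```python
import torch

B, top_k, d = 4, 2, 3
inp = torch.randn(B, d)
pos = torch.randperm(B * top_k)  # stand-in for the expert-sorted positions

# Old path: materialize the expanded batch, then index it.
expanded = inp.repeat(top_k, 1)            # row j of expanded == inp[j % B]
old = torch.index_select(expanded, 0, pos)

# New path: fold the duplication into the index, no extra copy.
new = torch.index_select(inp, 0, pos % inp.shape[0])

assert torch.equal(old, new)
```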
@@ -216,16 +223,14 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)

-        # to: (BxLxtop_k) x d_model
-        # TODO: remove repeat_interleave
-        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_general_global_forward(
-            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+            inp,
+            gate_top_k_idx,
+            self.expert_fn, self.num_expert, self.world_size
         )

-        # to: (BxL) x top_k x d_model
-        x = x.view(-1, self.top_k, self.d_model)
-        # to: (BxL) x d_model
-        gate_score = gate_score.unsqueeze(1)
+        x = x.view(inp.shape[0], self.top_k, self.d_model)
+        gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

         if self.mp_size > 1:
...
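The combine step at the end of `FMoE.forward`: after the gather, `x` holds top_k expert outputs per token, and the `bmm` contracts them against the gate weights by treating `gate_score` as a batch of 1 × top_k row vectors. An equivalent spelling as an explicit weighted sum (shapes and tensors here are illustrative):

```python
import torch

BxL, top_k, d_model = 5, 2, 8
x = torch.randn(BxL * top_k, d_model)                 # gathered expert outputs
gate_score = torch.softmax(torch.randn(BxL, top_k), dim=-1)

x = x.view(BxL, top_k, d_model)
out_bmm = torch.bmm(gate_score.view(BxL, 1, top_k), x).reshape(-1, d_model)

# The same reduction as an explicit weighted sum over the top_k axis.
out_sum = (gate_score.unsqueeze(-1) * x).sum(dim=1)
assert torch.allclose(out_bmm, out_sum)
```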