Commit 0fea2991 authored by Rick Ho

fix more bugs to make the layers run in the model

parent 6900f1de
@@ -13,10 +13,13 @@ def moe_prepare_forward(gate, num_expert, world_size):
            device=gate.device, dtype=torch.long)
    local_expert_count.index_put_((gate_idx.long(), ), gate_count)
    if world_size > 1:
        global_expert_count, = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
    else:
        global_expert_count = local_expert_count
    fwd_expert_count = global_expert_count.view(world_size,
-           num_expert).sum(dim=0).cpu()
+           num_expert).sum(dim=0)
    fwd_batch_size = int(fwd_expert_count.sum().item())
    return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
            fwd_expert_count.cpu(), fwd_batch_size)
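For reference, a minimal pure-PyTorch sketch of the expert-count bookkeeping this hunk touches (not the fmoe_cuda kernels; the helper name and the use of torch.unique are illustrative): given the gate's expert index for every token, count how many rows each expert will receive.

import torch

def count_tokens_per_expert(gate_idx, num_expert, world_size):
    # gate_idx: 1-D LongTensor with one expert id per (replicated) token.
    # One counter slot per expert on every worker, as in the code above.
    local_expert_count = torch.zeros(num_expert * world_size,
                                     dtype=torch.long, device=gate_idx.device)
    idx, counts = torch.unique(gate_idx, return_counts=True)
    local_expert_count.index_put_((idx,), counts)
    return local_expert_count

# 3 experts on a single worker, 5 tokens routed as below.
print(count_tokens_per_expert(torch.tensor([0, 2, 2, 1, 2]), 3, 1))
# tensor([1, 1, 3])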
@@ -35,6 +38,7 @@ class MOEScatter(Function):
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
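When world_size > 1, MOEScatter's forward path (truncated above) exchanges rows between workers. Below is a hedged sketch of that exchange using torch.distributed.all_to_all_single as a stand-in for the fmoe_cuda global-scatter kernel; the helper name and the split computation are assumptions, and it only runs under an initialized process group (e.g. launched with torchrun).

import torch
import torch.distributed as dist

def global_exchange(local_input_buf, in_splits, out_splits):
    # in_splits[w]:  number of rows this worker sends to worker w
    # out_splits[w]: number of rows this worker receives from worker w
    # (both derivable from local_expert_count / global_expert_count).
    out = local_input_buf.new_empty(sum(out_splits), local_input_buf.shape[1])
    dist.all_to_all_single(out, local_input_buf,
                           output_split_sizes=out_splits,
                           input_split_sizes=in_splits)
    return out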
@@ -57,14 +61,14 @@ class MOELinear(Function):
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                fwd_expert_count)
-       variables = (input_buf, weight, fwd_expert_count)
+       variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-       grad_inp_buf, grad_weight = ome_cuda.backward(
+       grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out, input_buf, weight, fwd_expert_count)
        return grad_inp_buf, grad_weight, None
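The fix above saves the actual forward input (global_input_buf) so it is available in backward. As a point of comparison, here is a pure-PyTorch stand-in for the same expert-wise linear pattern, using a naive loop over experts instead of the fmoe_cuda batched kernel; the class name is made up.

import torch
from torch.autograd import Function

class NaiveExpertLinear(Function):
    @staticmethod
    def forward(ctx, inp, weight, fwd_expert_count):
        # inp: (sum(fwd_expert_count), d_in); weight: (num_expert, d_out, d_in).
        outputs, start = [], 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            outputs.append(inp[start:start + n] @ weight[e].t())
            start += n
        ctx.save_for_backward(inp, weight, fwd_expert_count)
        return torch.cat(outputs, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, fwd_expert_count = ctx.saved_tensors
        grad_inp = torch.empty_like(inp)
        grad_weight = torch.zeros_like(weight)
        start = 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            grad_inp[start:start + n] = grad_out[start:start + n] @ weight[e]
            grad_weight[e] = grad_out[start:start + n].t() @ inp[start:start + n]
            start += n
        return grad_inp, grad_weight, None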
@@ -81,7 +85,7 @@ class MOEGather(Function):
            local_output_buf = global_output_buf
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
-       ctx.moe_args = fwd_batch_size, world_size
+       ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output
@@ -89,8 +93,8 @@ class MOEGather(Function):
    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
-       fwd_batch_size, world_size = ctx.moe_args
-       grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
+       local_batch_size, fwd_batch_size, world_size = ctx.moe_args
+       grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
        if world_size > 1:
            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                    local_expert_count, global_expert_count,
......
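The local_gather / local_scatter calls in the two MOEGather hunks reorder rows by the position index pos. A rough pure-PyTorch illustration of that reordering and of the two operations being inverses follows; the exact semantics of the real fmoe_cuda kernels may differ, so treat this as an assumption for illustration only.

import torch

def local_gather(buf, pos):
    # Row i of the result is row pos[i] of buf.
    return buf.index_select(0, pos)

def local_scatter(buf, pos):
    # Inverse: row i of buf is written back to position pos[i].
    out = torch.empty_like(buf)
    out.index_copy_(0, pos, buf)
    return out

x = torch.arange(8.0).view(4, 2)
pos = torch.tensor([2, 0, 3, 1])
assert torch.equal(local_scatter(local_gather(x, pos), pos), x)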
@@ -49,7 +49,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
-       x = l(x)
+       x = l(x, fwd_expert_count)
    x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
            inp.shape[0], world_size)
    return x
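The change above passes fwd_expert_count into each expert linear: every layer applies a different weight matrix to each expert's contiguous slice of the batch, so it needs the per-expert row counts. A hedged pure-PyTorch sketch of such a layer follows; ExpertwiseLinear is a hypothetical stand-in, not the real htoh4/h4toh modules.

import torch
import torch.nn as nn

class ExpertwiseLinear(nn.Module):
    def __init__(self, num_expert, d_in, d_out):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_expert, d_out, d_in) * 0.02)

    def forward(self, x, fwd_expert_count):
        out, start = [], 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            out.append(x[start:start + n] @ self.weight[e].t())
            start += n
        return torch.cat(out, dim=0)

htoh4 = ExpertwiseLinear(num_expert=2, d_in=4, d_out=16)
h4toh = ExpertwiseLinear(num_expert=2, d_in=16, d_out=4)
x = torch.randn(6, 4)
fwd_expert_count = torch.tensor([4, 2])   # 4 rows for expert 0, 2 for expert 1
x = h4toh(torch.relu(htoh4(x, fwd_expert_count)), fwd_expert_count)
print(x.shape)  # torch.Size([6, 4])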
@@ -78,16 +78,15 @@ class FMoETransformerMLP(nn.Module):
                dtype=torch.float32))

    def forward(self, inp):
        # import pdb; pdb.set_trace()
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
-       inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
-               dim=0)  # (BxLxtop_k) x d_model
        gate_top_k_idx, gate_score = self.gate(inp)
+       # TODO: merge replication into local_scatter
+       inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
+               dim=0)  # (BxLxtop_k) x d_model
        x = _fmoe_full_forward(inp, gate_top_k_idx,
                [self.htoh4, self.h4toh], self.activation,
                self.num_expert, self.world_size)
......
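The last hunk moves the top_k replication relative to the gate. A quick shape check of what repeat_interleave does there: every token row is duplicated top_k times, consecutively, so each copy can be routed to a different expert (the numbers below are arbitrary).

import torch

top_k, d_model = 2, 4
inp = torch.randn(3, 5, d_model)                    # (B, L, d_model)
flat = inp.view(-1, d_model)                        # (B*L, d_model)
rep = flat.repeat_interleave(repeats=top_k, dim=0)  # (B*L*top_k, d_model)
print(flat.shape, rep.shape)   # torch.Size([15, 4]) torch.Size([30, 4])
assert torch.equal(rep[0], rep[1])   # consecutive rows are copies of one token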