Commit 0fea2991 authored by Rick Ho

fix more bugs to make the layers run in the model

parent 6900f1de
@@ -13,10 +13,13 @@ def moe_prepare_forward(gate, num_expert, world_size):
             device=gate.device, dtype=torch.long)
     local_expert_count.index_put_((gate_idx.long(), ), gate_count)
-    global_expert_count, = fmoe_cuda.expert_exchange(
-            local_expert_count, num_expert, world_size)
+    if world_size > 1:
+        global_expert_count, = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size)
+    else:
+        global_expert_count = local_expert_count
     fwd_expert_count = global_expert_count.view(world_size,
-            num_expert).sum(dim=0).cpu()
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
             fwd_expert_count.cpu(), fwd_batch_size)
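In the updated moe_prepare_forward, the CUDA all-to-all count exchange is only issued when world_size > 1; on a single process the local counts are reused directly. For reference, a minimal single-process sketch of the per-expert counting pattern used above (the count_local_experts helper and the example gate tensor are illustrative and not part of the repository; gate_idx and gate_count are assumed to come from torch.unique):

import torch

def count_local_experts(gate, num_expert, world_size):
    # Count how many tokens were routed to each (worker, expert) slot.
    gate_idx, gate_count = torch.unique(gate, return_counts=True)
    local_expert_count = torch.zeros(num_expert * world_size,
                                     device=gate.device, dtype=torch.long)
    local_expert_count.index_put_((gate_idx.long(),), gate_count)
    return local_expert_count

gate = torch.tensor([0, 2, 2, 1, 0, 2])   # expert id chosen for each token
print(count_local_experts(gate, num_expert=4, world_size=1))
# tensor([2, 1, 3, 0])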
@@ -35,6 +38,7 @@ class MOEScatter(Function):
             global_input_buf = local_input_buf
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
         return global_input_buf

     @staticmethod
@@ -57,14 +61,14 @@ class MOELinear(Function):
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
         global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                 fwd_expert_count)
-        variables = (input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
         (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = ome_cuda.backward(
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 grad_out, input_buf, weight, fwd_expert_count)
         return grad_inp_buf, grad_weight, None
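Both fixes in the MOELinear hunk matter for the backward pass: the Function now saves the tensor that actually exists in scope (global_input_buf) and calls the correctly named extension (fmoe_cuda). As a rough reference for what the kernel pair computes, here is a pure-PyTorch sketch of an expert-wise linear autograd Function; the NaiveMOELinear name and the assumed weight layout (num_expert, d_in, d_out) are illustrative, not the fmoe_cuda implementation:

import torch
from torch.autograd import Function

class NaiveMOELinear(Function):
    # Tokens are assumed to be grouped contiguously by expert, with
    # fwd_expert_count[e] rows belonging to expert e.
    @staticmethod
    def forward(ctx, inp, weight, fwd_expert_count):
        ctx.save_for_backward(inp, weight, fwd_expert_count)
        out = torch.empty(inp.shape[0], weight.shape[-1],
                          dtype=inp.dtype, device=inp.device)
        start = 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            out[start:start + n] = inp[start:start + n] @ weight[e]
            start += n
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, fwd_expert_count = ctx.saved_tensors
        grad_inp = torch.empty_like(inp)
        grad_weight = torch.zeros_like(weight)
        start = 0
        for e, n in enumerate(fwd_expert_count.tolist()):
            grad_inp[start:start + n] = grad_out[start:start + n] @ weight[e].t()
            grad_weight[e] = inp[start:start + n].t() @ grad_out[start:start + n]
            start += n
        return grad_inp, grad_weight, None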
@@ -81,7 +85,7 @@ class MOEGather(Function):
             local_output_buf = global_output_buf
         output, = fmoe_cuda.local_scatter(local_output_buf, pos)
-        ctx.moe_args = fwd_batch_size, world_size
+        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
         ctx.save_for_backward(*variables)
         return output
@@ -89,8 +93,8 @@ class MOEGather(Function):
     @staticmethod
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
-        fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
+        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
+        grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
         if world_size > 1:
             global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                     local_expert_count, global_expert_count,
...
@@ -49,7 +49,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
     for i, l in enumerate(linears):
         if i:
             x = activation(x)
-        x = l(x)
+        x = l(x, fwd_expert_count)
     x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
             inp.shape[0], world_size)
     return x
@@ -78,16 +78,15 @@ class FMoETransformerMLP(nn.Module):
                 dtype=torch.float32))

     def forward(self, inp):
-        # import pdb; pdb.set_trace()
         residual = inp
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
-                dim=0) # (BxLxtop_k) x d_model
         gate_top_k_idx, gate_score = self.gate(inp)
+        # TODO: merge replication into local_scatter
+        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
+                dim=0) # (BxLxtop_k) x d_model
         x = _fmoe_full_forward(inp, gate_top_k_idx,
                 [self.htoh4, self.h4toh], self.activation,
                 self.num_expert, self.world_size)
...
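The FMoETransformerMLP.forward hunk reorders the computation so that the gate sees the original (B x L) tokens first, and only afterwards each token is replicated top_k times for expert dispatch. A small illustration of what repeat_interleave does at that point (the tensor contents and the top_k value are made up for the example):

import torch

top_k, d_model = 2, 4
inp = torch.arange(3 * d_model, dtype=torch.float32).view(3, d_model)  # 3 tokens
# Each row is repeated top_k times consecutively, so copy j of a token can be
# routed to its j-th selected expert.
replicated = inp.repeat_interleave(repeats=top_k, dim=0)
print(replicated.shape)  # torch.Size([6, 4])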