Commit 6900f1de authored by Rick Ho

fix python bugs

parent 437afda2
 from .moe import BruteForceMoE
-from .fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP
+from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
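Downstream imports through the package root are unaffected by the rename, since the same three classes are still re-exported. For example (assuming the installed package is named fmoe, which this diff does not state):

from fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP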
@@ -9,8 +9,8 @@ def moe_prepare_forward(gate, num_expert, world_size):
     with torch.no_grad():
         _, pos = torch.sort(gate)
         gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(weight.shape[0] * world_size,
-                device=weight.device, dtype=torch.long)
+        local_expert_count = torch.zeros(num_expert * world_size,
+                device=gate.device, dtype=torch.long)
         local_expert_count.index_put_((gate_idx.long(), ), gate_count)
         global_expert_count, = fmoe_cuda.expert_exchange(
@@ -28,7 +28,7 @@ class MOEScatter(Function):
             fwd_batch_size, world_size):
         local_input_buf, = fmoe_cuda.local_gather(inp, pos)
         if world_size > 1:
-            global_input_buf, = moe_cuda.global_scatter(local_input_buf,
+            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
                     local_expert_count, global_expert_count,
                     fwd_batch_size, world_size)
         else:
@@ -43,19 +43,19 @@ class MOEScatter(Function):
         (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
-            local_grad_in, = moe_cuda.global_gather(global_grad_out,
+            local_grad_in, = fmoe_cuda.global_gather(global_grad_out,
                     local_expert_count, global_expert_count,
                     local_batch_size, world_size)
         else:
             local_grad_in = global_grad_in
-        grad_in, = moe_cuda.local_scatter(local_grad_in, pos)
+        grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
         return grad_in, None, None, None, None, None

 class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
-        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
+        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                 fwd_expert_count)
         variables = (input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
@@ -74,12 +74,12 @@ class MOEGather(Function):
     def forward(ctx, global_output_buf, pos, local_expert_count,
             global_expert_count, local_batch_size, world_size):
         if world_size > 1:
-            local_output_buf, = moe_cuda.global_gather(global_output_buf,
+            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
                     local_expert_count, global_expert_count,
                     local_batch_size, world_size)
         else:
             local_output_buf = global_output_buf
-        output, = moe_cuda.local_scatter(local_output_buf, pos)
+        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
         ctx.moe_args = fwd_batch_size, world_size
         variables = (pos, local_expert_count, global_expert_count)
@@ -90,9 +90,9 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = moe_cuda.local_gather(grad_out.contiguous(), pos)
+        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
         if world_size > 1:
-            global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
+            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                     local_expert_count, global_expert_count,
                     fwd_batch_size, world_size)
         else:
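The first hunk's fix matters because weight is not defined inside moe_prepare_forward; the corrected code sizes and places the counter using the function's own arguments (num_expert and gate.device). Below is a minimal, CPU-only sketch of just that counting step, with a hypothetical helper name and without the fmoe_cuda.expert_exchange call that follows in the real function.

import torch

def count_local_experts(gate, num_expert, world_size):
    # gate: 1-D tensor of expert indices, one per token (hypothetical helper)
    with torch.no_grad():
        _, pos = torch.sort(gate)                      # token order grouped by expert
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        # one counter slot per (expert, worker) pair, on the same device as gate
        local_expert_count = torch.zeros(num_expert * world_size,
                                         device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
    return pos, local_expert_count

# 6 tokens routed among 4 experts on a single worker
pos, counts = count_local_experts(torch.tensor([2, 0, 2, 3, 0, 2]),
                                  num_expert=4, world_size=1)
print(counts)  # tensor([2, 0, 3, 1])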
 from .fmoe_functions import *
 import torch.nn as nn
+import torch.nn.functional as F

 class FMoELinear(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
-        super(FMoE, self).__init__()
+        super(FMoELinear, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
@@ -21,10 +22,11 @@ class FMoELinear(nn.Module):
         return MOELinear.apply(inp, self.weight, fwd_expert_count)

-class FMoENaiveGate(nn.module):
-    def __init__(self, num_expert=32, world_size=1, top_k=2):
+class FMoENaiveGate(nn.Module):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
         super(FMoENaiveGate, self).__init__()
         self.gate = nn.Linear(d_model, num_expert * world_size)
+        self.top_k = top_k

     def forward(self, inp):
         gate = self.gate(inp)
@@ -53,7 +55,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
     return x

-class FMoETransformerMLP(nn.module):
+class FMoETransformerMLP(nn.Module):
     def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
             world_size=None, activation=torch.nn.functional.gelu,
             top_k=2, pre_lnorm=False):
@@ -64,11 +66,12 @@ class FMoETransformerMLP(nn.module):
         self.world_size = world_size
         self.activation = activation
         self.pre_lnorm = pre_lnorm
+        self.top_k = top_k
         self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        self.gate = FMoENaivegate(num_expert, world_size, top_k)
+        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
         self.layer_norm = nn.LayerNorm(d_model)
         self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
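FMoENaiveGate now takes d_model as its first constructor argument, matching the nn.Linear it builds, and FMoETransformerMLP passes it accordingly. As a rough, self-contained sketch of what a gate of this shape computes, here is a plain-PyTorch top-k gate with the same constructor order; the top-k selection and softmax in forward are an assumption, since FMoENaiveGate.forward is only partially visible in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveTopKGate(nn.Module):
    # Hypothetical stand-in mirroring the corrected constructor order:
    # d_model first, then the expert counts, then top_k.
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super(NaiveTopKGate, self).__init__()
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        score = self.gate(inp)                              # [..., num_expert * world_size]
        val, idx = torch.topk(score, k=self.top_k, dim=-1)  # pick top_k experts per token
        return idx, F.softmax(val, dim=-1)                  # expert ids and mixing weights

gate = NaiveTopKGate(d_model=16, num_expert=4, world_size=1, top_k=2)
idx, weights = gate(torch.randn(8, 16))   # 8 tokens -> per-token expert ids and weights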
-from torch import nn
-from .moe import FMoE
-from .moe_function import moe
-from .fmoe import FMoETransformerMLP
+from .layers import FMoETransformerMLP

-class FFFN(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=None, activation=torch.nn.functional.gelu,
-            top_k=2, pre_lnorm=False):
-        super(FFFN, self).__init__()
-        self.d_model = d_model
-        self.d_hidden = d_hidden
-        self.world_size = world_size
-        self.activation = activation
-        self.top_k = top_k
-        self.pre_lnorm = pre_lnorm
-        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
-                world_size=world_size)
-        self.h4toh = FMoE(num_expert, d_hidden, d_model,
-                world_size=world_size)
-        self.gate = nn.Linear(d_model, num_expert * world_size)
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
-            dtype=torch.float32))
-
-    def forward(self, inp):
-        # import pdb; pdb.set_trace()
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
-                largest=True, sorted=False)  # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # (BxL) x 1 x top_k
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
-                dim=0)  # (BxLxtop_k) x d_model
-        x = self.htoh4(inp, gate_top_k_idx)
-        x = self.activation(x)
-        x = self.h4toh(x, gate_top_k_idx)
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output, self.bias

 def create_moe_mlp(args):
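The deleted FFFN module combined the top_k expert outputs by duplicating each token, softmaxing the scores of the chosen experts, and mixing the expert outputs with a batched matrix multiply; the same pattern now lives in FMoETransformerMLP. Below is a dependency-free sketch of just that combine step, with a single nn.Linear standing in for the expert MLPs; the shapes and names here are illustrative assumptions, not the library's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, d_model, num_expert, top_k = 2, 5, 16, 8, 2
inp = torch.randn(B, L, d_model)
gate_logits = torch.randn(B, L, num_expert)                 # per-token expert scores

val, idx = torch.topk(gate_logits, k=top_k, dim=-1)         # idx would drive expert routing
gate_score = F.softmax(val.view(-1, top_k), dim=-1).unsqueeze(1)   # (B*L) x 1 x top_k

# duplicate every token once per selected expert: (B*L*top_k) x d_model
x = inp.view(-1, d_model).repeat_interleave(repeats=top_k, dim=0)
expert_out = nn.Linear(d_model, d_model)(x)                 # stand-in for the expert MLPs

# weight the top_k outputs by the gate scores and fold them back per token
core_out = expert_out.view(-1, top_k, d_model)              # (B*L) x top_k x d_model
core_out = torch.bmm(gate_score, core_out).view(B, L, d_model)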