"server/text_generation_server/models/causal_lm.py" did not exist on "1539d3cbbef294df9c7ee9db07f45c97e6370f7b"
Commit 6900f1de authored by Rick Ho

fix python bugs

parent 437afda2
from .moe import BruteForceMoE
-from .fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP
+from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
@@ -9,8 +9,8 @@ def moe_prepare_forward(gate, num_expert, world_size):
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(weight.shape[0] * world_size,
-                device=weight.device, dtype=torch.long)
+        local_expert_count = torch.zeros(num_expert * world_size,
+                device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
        global_expert_count, = fmoe_cuda.expert_exchange(
@@ -28,7 +28,7 @@ class MOEScatter(Function):
            fwd_batch_size, world_size):
        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
        if world_size > 1:
-            global_input_buf, = moe_cuda.global_scatter(local_input_buf,
+            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
@@ -43,19 +43,19 @@ class MOEScatter(Function):
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
        if world_size > 1:
-            local_grad_in, = moe_cuda.global_gather(global_grad_out,
+            local_grad_in, = fmoe_cuda.global_gather(global_grad_out,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_grad_in = global_grad_in
-        grad_in, = moe_cuda.local_scatter(local_grad_in, pos)
+        grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
        return grad_in, None, None, None, None, None


class MOELinear(Function):
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
-        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
+        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                fwd_expert_count)
        variables = (input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
@@ -74,12 +74,12 @@ class MOEGather(Function):
    def forward(ctx, global_output_buf, pos, local_expert_count,
            global_expert_count, local_batch_size, world_size):
        if world_size > 1:
-            local_output_buf, = moe_cuda.global_gather(global_output_buf,
+            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_output_buf = global_output_buf
-        output, = moe_cuda.local_scatter(local_output_buf, pos)
+        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
        ctx.moe_args = fwd_batch_size, world_size
        variables = (pos, local_expert_count, global_expert_count)
@@ -90,9 +90,9 @@ class MOEGather(Function):
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = moe_cuda.local_gather(grad_out.contiguous(), pos)
+        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
        if world_size > 1:
-            global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
+            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
......
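For reference, the expert-counting step that the first hunk above fixes can be reproduced with plain PyTorch on a single worker. The sketch below mirrors the fixed lines but stops before the fmoe_cuda.expert_exchange call; the helper name count_local_experts and the sample data are illustrative, not part of the commit.

import torch

def count_local_experts(gate, num_expert, world_size=1):
    # Mirror of the fixed lines: the histogram has num_expert * world_size
    # slots and lives on gate.device, not on an expert weight tensor.
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(num_expert * world_size,
                device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
    return pos, local_expert_count

# Six tokens routed among four experts on one worker.
pos, counts = count_local_experts(torch.tensor([2, 0, 2, 3, 0, 2]), num_expert=4)
print(counts.tolist())  # [2, 0, 3, 1]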
from .fmoe_functions import *
import torch.nn as nn
import torch.nn.functional as F


class FMoELinear(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
-        super(FMoE, self).__init__()
+        super(FMoELinear, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
@@ -21,10 +22,11 @@ class FMoELinear(nn.Module):
        return MOELinear.apply(inp, self.weight, fwd_expert_count)


-class FMoENaiveGate(nn.module):
-    def __init__(self, num_expert=32, world_size=1, top_k=2):
+class FMoENaiveGate(nn.Module):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super(FMoENaiveGate, self).__init__()
+        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        gate = self.gate(inp)
@@ -53,7 +55,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    return x


-class FMoETransformerMLP(nn.module):
+class FMoETransformerMLP(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
            world_size=None, activation=torch.nn.functional.gelu,
            top_k=2, pre_lnorm=False):
@@ -64,11 +66,12 @@ class FMoETransformerMLP(nn.module):
        self.world_size = world_size
        self.activation = activation
        self.pre_lnorm = pre_lnorm
+        self.top_k = top_k
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        self.gate = FMoENaivegate(num_expert, world_size, top_k)
+        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
......
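The gate's constructor now takes d_model first, which is why FMoETransformerMLP builds it as FMoENaiveGate(d_model, num_expert, world_size, top_k). Below is a self-contained sketch of such a naive top-k gate; the forward pass follows the top-k/softmax pattern used by FFFN further down, since the gate's own forward body is not shown in this diff, and the class name NaiveTopKGate is illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveTopKGate(nn.Module):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        # One score per (expert, worker) pair, as in the diff above.
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        score = self.gate(inp)                        # (..., num_expert * world_size)
        top_val, top_idx = torch.topk(score, k=self.top_k, dim=-1)
        return top_idx, F.softmax(top_val, dim=-1)    # expert ids, mixing weights

gate = NaiveTopKGate(d_model=16, num_expert=4, world_size=1, top_k=2)
idx, weight = gate(torch.randn(8, 16))
print(idx.shape, weight.shape)  # torch.Size([8, 2]) torch.Size([8, 2])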
from torch import nn
from .moe import FMoE
from .moe_function import moe
from .fmoe import FMoETransformerMLP


class FFFN(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
            world_size=None, activation=torch.nn.functional.gelu,
            top_k=2, pre_lnorm=False):
        super(FFFN, self).__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm
        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
                world_size=world_size)
        self.h4toh = FMoE(num_expert, d_hidden, d_model,
                world_size=world_size)
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
            dtype=torch.float32))
    def forward(self, inp):
        # import pdb; pdb.set_trace()
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
                largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
                dim=0)  # (BxLxtop_k) x d_model
        x = self.htoh4(inp, gate_top_k_idx)
        x = self.activation(x)
        x = self.h4toh(x, gate_top_k_idx)
        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
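The weighted combination at the end of FFFN.forward can be checked in isolation: the softmaxed top-k gate scores mix the per-expert outputs through a batched matrix product, i.e. a convex combination over each token's top-k experts. The sketch below uses random tensors with the shapes noted in the comments above; it is an illustration, not repository code.

import torch
import torch.nn.functional as F

batch_tokens, top_k, d_model = 6, 2, 4                        # (BxL), top_k, d_model
gate_top_k_val = torch.randn(batch_tokens, top_k)
expert_out = torch.randn(batch_tokens, top_k, d_model)        # stands in for the h4toh output

gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)   # (BxL) x 1 x top_k
core_out = torch.bmm(gate_score, expert_out).squeeze(1)       # (BxL) x d_model

# Same result written as an explicit weighted sum over the top-k experts.
check = (gate_score.squeeze(1).unsqueeze(-1) * expert_out).sum(dim=1)
print(torch.allclose(core_out, check))  # True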
from .layers import FMoETransformerMLP
def create_moe_mlp(args):
......