Commit 437afda2 authored by Rick Ho

reconstruct fmoe nn module

parent 5e0af68d
from .moe import FMoE, BruteForceMoE
from .moe import BruteForceMoE
from .fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP
from .fmoe_functions import *
import torch
import torch.nn as nn
import torch.nn.functional as F
class FMoELinear(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(FMoELinear, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # one (out_feat x in_feat) weight matrix per expert
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # initialise each expert's weight the same way nn.Linear does
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat,
                               out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, fwd_expert_count):
        # rows of inp are grouped by expert; fwd_expert_count[i] rows go to expert i
        return MOELinear.apply(inp, self.weight, fwd_expert_count)
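
# Reference sketch (illustrative): what FMoELinear computes, written as a plain
# Python loop instead of the fused MOELinear kernel. Rows of `inp` are assumed
# to be grouped by expert, with fwd_expert_count[i] rows belonging to expert i.
# The helper name below is hypothetical and only meant to document the layout.
def _fmoe_linear_reference(inp, weight, fwd_expert_count):
    outputs = []
    base = 0
    for i, cnt in enumerate(fwd_expert_count):
        cnt = int(cnt)
        # expert i's slice: (cnt, in_feat) @ (in_feat, out_feat) -> (cnt, out_feat)
        outputs.append(inp[base:base + cnt] @ weight[i].t())
        base += cnt
    return torch.cat(outputs, dim=0)
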
class FMoENaiveGate(nn.Module):
    def __init__(self, d_model, num_expert=32, world_size=1, top_k=2):
        super(FMoENaiveGate, self).__init__()
        # one score per (worker, expert) pair
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
                                                    largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
        return gate_top_k_idx, gate_score
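
# Shape sketch (illustrative, assuming the d_model-first signature used above):
# for an input of shape (batch, seq_len, d_model) the gate returns a flat index
# tensor of length batch * seq_len * top_k and a score tensor of shape
# (batch * seq_len, 1, top_k) that is later bmm-ed against expert outputs.
#
#   gate = FMoENaiveGate(d_model=64, num_expert=4, world_size=1, top_k=2)
#   idx, score = gate(torch.randn(2, 3, 64))
#   # idx.shape == (2 * 3 * 2,)   score.shape == (2 * 3, 1, 2)
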
def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    # count how many rows go to each expert and compute the forward batch layout
    (pos, local_expert_count, global_expert_count, fwd_expert_count,
        fwd_batch_size) = moe_prepare_forward(gate, num_expert, world_size)
    # scatter: reorder rows (and exchange them across workers when world_size > 1)
    # so that each expert receives a contiguous block of its inputs
    x = MOEScatter.apply(inp, pos, local_expert_count, global_expert_count,
                         fwd_batch_size, world_size)
    # run the expert MLP: linear -> activation -> linear -> ...
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
        x = l(x)
    # gather: send rows back to their source and restore the original order
    x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
                        inp.shape[0], world_size)
    return x
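
# Single-process reference sketch (illustrative, world_size == 1, no custom
# kernels): it mimics the scatter -> per-expert MLP -> gather flow above by
# sorting tokens by their gate index, running each expert's slice through the
# two weight tensors, and restoring the original token order. The helper name
# and the raw weight-tensor experts are assumptions for illustration only.
def _moe_forward_reference(inp, gate_idx, experts_w1, experts_w2, activation, num_expert):
    # inp: (tokens, d_model); gate_idx: (tokens,) expert id per row
    pos = torch.argsort(gate_idx)                              # scatter: group rows by expert
    counts = torch.bincount(gate_idx, minlength=num_expert)
    x = inp[pos]
    out = torch.empty(x.shape[0], experts_w2[0].shape[0],
                      device=x.device, dtype=x.dtype)
    base = 0
    for e in range(num_expert):
        cnt = int(counts[e])
        if cnt == 0:
            continue
        h = activation(x[base:base + cnt] @ experts_w1[e].t())  # htoh4 for expert e
        out[base:base + cnt] = h @ experts_w2[e].t()            # h4toh for expert e
        base += cnt
    y = torch.empty_like(out)
    y[pos] = out                                                # gather: restore original order
    return y
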
class FMoETransformerMLP(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
                 world_size=None, activation=torch.nn.functional.gelu,
                 top_k=2, pre_lnorm=False):
        super(FMoETransformerMLP, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = nn.Parameter(torch.zeros(d_model, dtype=torch.float32))
    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        # gate on the original tokens, then replicate each token top_k times
        gate_top_k_idx, gate_score = self.gate(inp)
        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
                                                           dim=0)  # (BxLxtop_k) x d_model
        x = _fmoe_full_forward(inp, gate_top_k_idx,
                               [self.htoh4, self.h4toh], self.activation,
                               self.num_expert, self.world_size)
        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        # weighted sum of the top_k expert outputs
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
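
# Usage sketch (commented out because it needs a CUDA device and the compiled
# fmoe extension backing MOEScatter/MOEGather/MOELinear; world_size=1 assumed).
# Note the layer returns (output, bias) rather than adding the bias itself.
#
#   mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096,
#                            world_size=1, top_k=2).cuda()
#   x = torch.randn(8, 16, 1024, device='cuda')   # (batch, seq_len, d_model)
#   out, bias = mlp(x)                            # out: (8, 16, 1024), bias: (1024,)
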
from torch import nn
from .moe import FFFN
from .moe import FMoE
from .moe_function import moe
from .fmoe import FMoETransformerMLP
class FFFN(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
                 world_size=None, activation=torch.nn.functional.gelu,
                 top_k=2, pre_lnorm=False):
        super(FFFN, self).__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm
        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
                          world_size=world_size)
        self.h4toh = FMoE(num_expert, d_hidden, d_model,
                          world_size=world_size)
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
                                                             dtype=torch.float32))
    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
                                                    largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
                                                           dim=0)  # (BxLxtop_k) x d_model
        x = self.htoh4(inp, gate_top_k_idx)
        x = self.activation(x)
        x = self.h4toh(x, gate_top_k_idx)
        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
def create_moe_mlp(args):
    assert args.num_experts % args.model_parallel_size == 0, \
        'Num experts should be multiple of mp size'
    num_experts = args.num_experts // args.model_parallel_size
    fmoe = FFFN(num_experts,
    fmoe = FMoETransformerMLP(num_experts,
                              d_model=args.hidden_size,
                              d_hidden=args.hidden_size * 4,
                              world_size=args.model_parallel_size)
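
# Call sketch (illustrative): create_moe_mlp only reads a few attributes from
# `args`, so a bare namespace is enough to exercise it. The attribute names
# come from the function body above; the assumption that the truncated tail of
# the function returns `fmoe` is mine.
#
#   from argparse import Namespace
#   args = Namespace(num_experts=8, model_parallel_size=2, hidden_size=1024)
#   mlp = create_moe_mlp(args)   # 8 // 2 = 4 experts per model-parallel rank
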
@@ -27,58 +27,6 @@ class FMoE(nn.Module):
        return moe(inp, gate.int(), self.weight, self.world_size)


class FFFN(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
                 world_size=None, activation=torch.nn.functional.gelu,
                 top_k=2, pre_lnorm=False):
        super(FFFN, self).__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm
        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
                          world_size=world_size)
        self.h4toh = FMoE(num_expert, d_hidden, d_model,
                          world_size=world_size)
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
                                                             dtype=torch.float32))
    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
                                                    largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
                                                           dim=0)  # (BxLxtop_k) x d_model
        x = self.htoh4(inp, gate_top_k_idx)
        x = self.activation(x)
        x = self.h4toh(x, gate_top_k_idx)
        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
class BruteForceMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=0):