Commit 58e949cf authored by Rick Ho

Initial version to run with Megatron

parent f866ed0f
...
@@ -5,8 +5,9 @@ from .moe import FFFN

 def create_moe_mlp(args):
     assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
     num_experts = args.num_experts // args.model_parallel_size
-    fmoe = FFFN(num_experts, in_feat=args.hidden_size,
-            hidden_feat=args.hidden_size * 4, out_feat=args.hidden_size,
-            world_size = args.model_parallel_size)
+    fmoe = FFFN(num_experts,
+            d_model=args.hidden_size,
+            d_hidden=args.hidden_size * 4,
+            world_size=args.model_parallel_size)
     return fmoe
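For orientation, here is a minimal usage sketch of the updated create_moe_mlp driven from a Megatron-style argument namespace. The namespace fields mirror the three attributes the function reads (num_experts, model_parallel_size, hidden_size); the concrete values and the use of argparse.Namespace are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (not from this commit): exercising create_moe_mlp
# with an argparse-style namespace carrying the three attributes it reads.
import argparse

# Assumed example values; num_experts must be divisible by model_parallel_size.
args = argparse.Namespace(num_experts=4, model_parallel_size=2, hidden_size=1024)

fmoe = create_moe_mlp(args)  # assumes create_moe_mlp from the file above is importable
# -> FFFN with 2 experts per rank, d_model=1024, d_hidden=4096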
 import math
 from torch import nn
 import torch
+import torch.nn.functional as F
 from .moe_function import moe
...
@@ -27,20 +28,55 @@ class FMoE(nn.Module):
 class FFFN(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096,
-            out_feat=1024, world_size=None, activation=torch.nn.functional.gelu):
+    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
+            world_size=None, activation=torch.nn.functional.gelu,
+            top_k=2, pre_lnorm=False):
         super(FFFN, self).__init__()
-        self.htoh4 = FMoE(num_expert, in_feat, hidden_feat,
-                world_size=world_size)
+        self.d_model = d_model
+        self.d_hidden = d_hidden
+        self.world_size = world_size
         self.activation = activation
-        self.h4toh = FMoE(num_expert, hidden_feat, out_feat,
+        self.top_k = top_k
+        self.pre_lnorm = pre_lnorm
+        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
                 world_size=world_size)
+        self.h4toh = FMoE(num_expert, d_hidden, d_model,
+                world_size=world_size)
+        self.gate = nn.Linear(d_model, num_expert)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
+            dtype=torch.float32))

-    def forward(self, inp, gate):
-        x = self.htoh4(inp)
+    def forward(self, inp):
+        # import pdb; pdb.set_trace()
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
+                largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        # (BxL) x 1 x top_k
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
+        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
+        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k,
+                dim=0)  # (BxLxtop_k) x d_model
+        x = self.htoh4(inp, gate_top_k_idx)
         x = self.activation(x)
-        x = self.h4toh(x)
-        return x
+        x = self.h4toh(x, gate_top_k_idx)
+
+        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
+        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output, self.bias


 class BruteForceMoE(nn.Module):
...
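The rewritten FFFN.forward is a standard top-k gated mixture-of-experts step: score each token against num_expert gate logits, keep the top_k logits, softmax only over those, push each token through its selected experts, then recombine the top_k expert outputs with a batched matmul before the residual add and layer norm. Returning (output, self.bias) appears intended to match the (output, bias) pair Megatron's MLP block consumes. Below is a shape-level sketch of just the gate-and-combine path; a single shared nn.Linear stands in for the custom FMoE expert layers, and that stand-in plus the toy sizes are assumptions for illustration, not the commit's implementation.

import torch
import torch.nn.functional as F

# Shape-level sketch of the top-k gate/combine used in FFFN.forward above.
# The real code dispatches tokens to per-expert FMoE kernels; here one shared
# Linear stands in (an assumption) so the tensor shapes can be checked end to end.
B, L, d_model, num_expert, top_k = 2, 3, 8, 4, 2
inp = torch.randn(B, L, d_model)
gate_logits = torch.nn.Linear(d_model, num_expert)(inp)              # B x L x num_expert

gate_top_k_val, gate_top_k_idx = torch.topk(
    gate_logits, k=top_k, dim=-1, largest=True, sorted=False)        # B x L x top_k
gate_score = F.softmax(gate_top_k_val.view(-1, top_k), dim=-1)       # (BxL) x top_k
gate_score = gate_score.unsqueeze(1)                                  # (BxL) x 1 x top_k

# Duplicate every token once per selected expert, as the diff does.
flat = inp.view(-1, d_model).repeat_interleave(top_k, dim=0)          # (BxLxtop_k) x d_model
expert_out = torch.nn.Linear(d_model, d_model)(flat)                  # stand-in for htoh4/h4toh

# Weighted combine of the top_k expert outputs per token.
core_out = expert_out.view(-1, top_k, d_model)                        # (BxL) x top_k x d_model
core_out = torch.bmm(gate_score, core_out).view(B, L, d_model)        # back to B x L x d_model
print(core_out.shape)  # torch.Size([2, 3, 8])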