Commit d2678111 authored by Rick Ho

a stronger benchmark

parent c5556037
from fmoe import FMoETransformerMLP as MOELayer
from fmoe import FMoETransformerMLP
from fmoe.gates import NaiveGate
from moe import BruteForceMoELinear
import torch
import torch.nn as nn
import time
import sys
import os
@@ -10,13 +13,45 @@ world_size = None
dev_name_default = 'cuda:0'
def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
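# Brute-force reference MoE MLP, used below as a baseline against FMoETransformerMLP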
class BruteForceMoE(nn.Module):
def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
world_size=1, mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate, top_k=1, pre_lnorm=False):
assert world_size == 1, 'Distributed brute force is not supported'
super().__init__()
self.mlp1 = BruteForceMoELinear(num_expert, d_model, d_hidden, 1)
self.mlp2 = BruteForceMoELinear(num_expert, d_hidden, d_model, 1)
self.activation = activation
self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k)
self.pre_lnorm = pre_lnorm
self.layer_norm = nn.LayerNorm(d_model)
self.d_model = d_model
def forward(self, inp):
if self.pre_lnorm:
inp = self.layer_norm(inp)
# select the top-k experts and their softmaxed scores for every token
gate_top_k_idx, gate_score = self.gate(inp)
# duplicate each token once per selected expert
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = self.mlp1(inp, gate_top_k_idx)
x = self.activation(x)
x = self.mlp2(x, gate_top_k_idx)
# recombine the top-k expert outputs per token, weighted by the gate scores
x = x.view(-1, self.top_k, self.d_model)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if not self.pre_lnorm:
x = self.layer_norm(x)
return x
def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if rank == 0:
print('Performance test case bs {} {}x{} ne {}x{} topk {}'.format(
batch_size, in_feat, hidden_feat, world_size, num_expert, top_k))
print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'
.format(MOELayer.__name__, batch_size, in_feat, hidden_feat,
world_size, num_expert, top_k))
if world_size > 1:
dev_name = 'cuda'
else:
@@ -53,7 +88,7 @@ def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
tott += te - ts
sqtot += (te - ts)**2
maxt = max(maxt, te - ts)
backt = bte - bts
backt += bte - bts
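# FLOPs per run: the two expert GEMMs for each of the batch_size * top_k routed tokens,
# plus the gate's in_feat x num_expert projection; 2e-9 counts a multiply-add as 2 FLOPs and converts to GFLOPs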
gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
batch_size * in_feat * num_expert) / tott
@@ -74,4 +109,13 @@ if __name__ == '__main__':
rank = 0
world_size = 1
test_performance(4096, 1024, 4096, 8, 8)
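# benchmark dimensions are configurable through environment variables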
batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
d_model = int(os.environ.get('D_MODEL', '1024'))
d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
num_expert = int(os.environ.get('NUM_EXPERT', '8'))
top_k = int(os.environ.get('TOP_K', '2'))
benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
d_hidden, num_expert, top_k)
if world_size == 1:
benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert,
top_k)
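As a reference for what BruteForceMoE.forward computes, the top-k combine is just a batched matmul over per-token expert outputs. Below is a minimal standalone sketch of that shape flow; it assumes NaiveGate returns softmaxed scores of shape (batch, 1, top_k), which is an assumption, since the gate itself is not shown in this diff.

import torch

batch, top_k, d_model = 4, 2, 8
# outputs of the selected experts, laid out one row per (token, expert) pair
expert_out = torch.randn(batch * top_k, d_model).view(batch, top_k, d_model)
# assumed gate score shape: (batch, 1, top_k), softmaxed over the selected experts
gate_score = torch.softmax(torch.randn(batch, 1, top_k), dim=-1)
# (b,1,k) @ (b,k,d) -> (b,1,d), then flatten back to one row per token
combined = torch.bmm(gate_score, expert_out).reshape(-1, d_model)
# explicit weighted sum, for comparison
manual = (gate_score.transpose(1, 2) * expert_out).sum(dim=1)
assert torch.allclose(combined, manual, atol=1e-6)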
@@ -3,33 +3,11 @@ from torch import nn
import torch
import torch.nn.functional as F
from fmoe.layers import FMoELinear, _fmoe_full_forward
class FMoE(nn.Module):
def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
world_size=1):
super(FMoE, self).__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.world_size = world_size
self.linear = FMoELinear(num_expert, in_feat, out_feat)
self.weight = self.linear.weight
self.reset_parameters()
def reset_parameters(self):
self.linear.reset_parameters()
def forward(self, inp, gate):
return _fmoe_full_forward(inp, gate, [self.linear], None,
self.num_expert, self.world_size)
class BruteForceMoE(nn.Module):
class BruteForceMoELinear(nn.Module):
def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
world_size=0):
super(BruteForceMoE, self).__init__()
super(BruteForceMoELinear, self).__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
@@ -42,13 +20,16 @@ class BruteForceMoE(nn.Module):
for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat,
out_features=self.out_feat)
# print(linear.weight.shape)
self.weight.data[i] = linear.weight.data
def forward(self, inp, gate):
gate_long = gate.long()
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.out_feat))
for i in range(batch_size):
x[i] = inp[i] @ self.weight[gate_long[i]].t()
return x
o = torch.empty(batch_size, self.out_feat, dtype=inp.dtype,
device=inp.device)
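# gather all rows routed to expert i and apply its weight in a single matmul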
for i in range(self.num_expert):
idx = (gate == i)
x = inp[idx]
x = x @ self.weight[i].t()
o[idx] = x
return o
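The rewritten forward trades the old per-sample loop for one matmul per expert. A minimal standalone check that the two formulations agree (random weights and routing; the names here are illustrative, not part of the repo):

import torch

num_expert, in_feat, out_feat, batch = 4, 16, 32, 64
weight = torch.randn(num_expert, out_feat, in_feat)
inp = torch.randn(batch, in_feat)
gate = torch.randint(0, num_expert, (batch,))

# old formulation: one matrix-vector product per sample
ref = torch.stack([inp[i] @ weight[gate[i]].t() for i in range(batch)])

# new formulation: gather the rows routed to each expert and multiply once
out = torch.empty(batch, out_feat)
for e in range(num_expert):
    idx = (gate == e)
    out[idx] = inp[idx] @ weight[e].t()

assert torch.allclose(ref, out, atol=1e-4)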