Commit d2678111 authored by Rick Ho

a stronger benchmark

parent c5556037
-from fmoe import FMoETransformerMLP as MOELayer
+from fmoe import FMoETransformerMLP
+from fmoe.gates import NaiveGate
+from moe import BruteForceMoELinear
 import torch
+import torch.nn as nn
 import time
 import sys
 import os
@@ -10,13 +13,45 @@ world_size = None
 dev_name_default = 'cuda:0'
 
-def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
+
+class BruteForceMoE(nn.Module):
+    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
+                 world_size=1, mp_group=None,
+                 activation=torch.nn.functional.gelu,
+                 gate=NaiveGate, top_k=1, pre_lnorm=False):
+        assert world_size == 1, 'Distributed brute force is not supported'
+        super().__init__()
+        self.mlp1 = BruteForceMoELinear(num_expert, d_model, d_hidden, 1)
+        self.mlp2 = BruteForceMoELinear(num_expert, d_hidden, d_model, 1)
+        self.activation = activation
+        self.top_k = top_k
+        self.gate = gate(d_model, num_expert, world_size, top_k)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.d_model = d_model
+
+    def forward(self, inp):
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+        gate_top_k_idx, gate_score = self.gate(inp)
+        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
+        x = self.mlp1(inp, gate_top_k_idx)
+        x = self.activation(x)
+        x = self.mlp2(x, gate_top_k_idx)
+        x = x.view(-1, self.top_k, self.d_model)
+        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
+        if not self.pre_lnorm:
+            x = self.layer_norm(x)
+        return x
+
+
+def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
-        print('Performance test case bs {} {}x{} ne {}x{} topk {}'.format(
-            batch_size, in_feat, hidden_feat, world_size, num_expert, top_k))
+        print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'
+              .format(MOELayer.__name__, batch_size, in_feat, hidden_feat,
+                      world_size, num_expert, top_k))
     if world_size > 1:
         dev_name = 'cuda'
     else:
@@ -53,7 +88,7 @@ def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
         tott += te - ts
         sqtot += (te - ts)**2
         maxt = max(maxt, te - ts)
-        backt = bte - bts
+        backt += bte - bts
     gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
                               batch_size * in_feat * num_expert) / tott
@@ -74,4 +109,13 @@ if __name__ == '__main__':
     rank = 0
     world_size = 1
-    test_performance(4096, 1024, 4096, 8, 8)
+    batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
+    d_model = int(os.environ.get('D_MODEL', '1024'))
+    d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
+    num_expert = int(os.environ.get('NUM_EXPERT', '8'))
+    top_k = int(os.environ.get('TOP_K', '2'))
+    benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
+                  d_hidden, num_expert, top_k)
+    if world_size == 1:
+        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert,
+                      top_k)
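Note on the GFLOP/s figure printed by benchmark_mlp: the formula reads as counting the two expert GEMMs (mlp1 and mlp2, each in_feat x hidden_feat) for every token and every one of its top_k selected experts, plus the gate's in_feat x num_expert projection, over the total forward time tott. A small standalone sketch of that arithmetic, using the default configuration from the __main__ block; n_runs and tott here are made-up placeholder values, not measurements from the commit:

# Defaults from the __main__ block above; n_runs and tott are
# hypothetical, chosen only to make the formula concrete.
batch_size, in_feat, hidden_feat = 4096, 1024, 4096
num_expert, top_k = 8, 2
n_runs, tott = 16, 1.0  # assumed iteration count and forward seconds

# Same accounting as the script: the inner *2 covers the two expert
# GEMMs (mlp1, mlp2), the second term is the gating projection, and
# the leading factor 2 counts multiply and add separately.
per_run = (in_feat * hidden_feat * batch_size * top_k * 2
           + batch_size * in_feat * num_expert)
gflops = 2e-9 * n_runs * per_run / tott
print('%.1f GFLOP/s (forward only)' % gflops)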
@@ -3,33 +3,11 @@ from torch import nn
 import torch
 import torch.nn.functional as F
-from fmoe.layers import FMoELinear, _fmoe_full_forward
 
-class FMoE(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
-                 world_size=1):
-        super(FMoE, self).__init__()
-        self.num_expert = num_expert
-        self.in_feat = in_feat
-        self.out_feat = out_feat
-        self.world_size = world_size
-        self.linear = FMoELinear(num_expert, in_feat, out_feat)
-        self.weight = self.linear.weight
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        self.linear.reset_parameters()
-
-    def forward(self, inp, gate):
-        return _fmoe_full_forward(inp, gate, [self.linear], None,
-                                  self.num_expert, self.world_size)
-
-class BruteForceMoE(nn.Module):
+class BruteForceMoELinear(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                  world_size=0):
-        super(BruteForceMoE, self).__init__()
+        super(BruteForceMoELinear, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
@@ -42,13 +20,16 @@ class BruteForceMoE(nn.Module):
         for i in range(self.num_expert):
             linear = nn.Linear(in_features=self.in_feat,
                                out_features=self.out_feat)
-            # print(linear.weight.shape)
             self.weight.data[i] = linear.weight.data
 
     def forward(self, inp, gate):
         gate_long = gate.long()
         batch_size = inp.size(0)
-        x = inp.new_zeros((batch_size, self.out_feat))
-        for i in range(batch_size):
-            x[i] = inp[i] @ self.weight[gate_long[i]].t()
-        return x
+        o = torch.empty(batch_size, self.out_feat, dtype=inp.dtype,
+                        device=inp.device)
+        for i in range(self.num_expert):
+            idx = (gate == i)
+            x = inp[idx]
+            x = x @ self.weight[i].t()
+            o[idx] = x
+        return o
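The reworked forward loops over experts instead of over individual tokens: for each expert it selects the rows routed to it with a boolean mask, applies a single dense GEMM, and scatters the results back into the output buffer. A minimal sketch of exercising the layer on its own follows; the tensor sizes and the random top-1 gate are illustrative assumptions, and it presumes moe.py (the file in this hunk, imported as `moe` by the benchmark above) is on the import path:

import torch
from moe import BruteForceMoELinear

# Small illustrative sizes, not the benchmark defaults.
num_expert, in_feat, out_feat, batch_size = 4, 8, 16, 32
layer = BruteForceMoELinear(num_expert, in_feat, out_feat)

inp = torch.randn(batch_size, in_feat)
# One expert index per row, as a top-1 gate would produce.
gate = torch.randint(0, num_expert, (batch_size,))

out = layer(inp, gate)
assert out.shape == (batch_size, out_feat)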