Unverified commit 1a72a0cb, authored by Rick Ho, committed by GitHub

Merge pull request #3 from laekov/laekov/benchmarks

Separate benchmarks from tests
parents 34477955 101b847c
-from fmoe import FMoETransformerMLP as MOELayer
+from fmoe import FMoETransformerMLP
+from fmoe.gates import NaiveGate
+from moe import BruteForceMoELinear
 import torch
 import torch.nn as nn
 import time
 import sys
+import os
@@ -10,13 +13,39 @@ world_size = None
 dev_name_default = 'cuda:0'
 
-def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
+class BruteForceMoE(nn.Module):
+    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
+                 world_size=1, mp_group=None,
+                 activation=torch.nn.functional.gelu,
+                 gate=NaiveGate, top_k=1, pre_lnorm=False):
+        assert world_size == 1, 'Distributed brute force is not supported'
+        super().__init__()
+        self.mlp = BruteForceMoELinear(activation, num_expert, d_model,
+                                       d_hidden, 1, top_k)
+        self.top_k = top_k
+        self.gate = gate(d_model, num_expert, world_size, top_k)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.d_model = d_model
+
+    def forward(self, inp):
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+        gate_top_k_idx, gate_score = self.gate(inp)
+        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
+        x = self.mlp(inp, gate_top_k_idx, gate_score)
+        if not self.pre_lnorm:
+            x = self.layer_norm(x)
+        return x
+
+
+def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert,
+                  top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
-        print('Performance test case bs {} {}x{} ne {}x{} topk {}'.format(
-            batch_size, in_feat, hidden_feat, world_size, num_expert, top_k))
+        print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'
+              .format(MOELayer.__name__, batch_size, in_feat, hidden_feat,
+                      world_size, num_expert, top_k))
     if world_size > 1:
         dev_name = 'cuda'
     else:
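
The BruteForceMoE class added above gives the benchmark a naive reference module to compare FMoETransformerMLP against. Below is a minimal, self-contained sketch (illustrative, not the library's code) of the shape bookkeeping its forward pass relies on, assuming the gate yields per-token expert scores shaped (batch, 1, top_k):

import torch

batch, top_k, d_model = 4, 2, 8
inp = torch.randn(batch, d_model)

# Duplicate each token top_k times so every copy can visit one expert;
# row i*top_k + j of `expanded` is token i's j-th expert choice.
expanded = inp.repeat_interleave(repeats=top_k, dim=0)  # (batch*top_k, d_model)

# Stand-in for the expert MLPs: identity, to keep the sketch self-contained.
expert_out = expanded

# Combine each token's top_k expert outputs, weighted by the gate scores.
gate_score = torch.softmax(torch.randn(batch, 1, top_k), dim=-1)
combined = torch.bmm(gate_score, expert_out.view(batch, top_k, d_model))
print(combined.reshape(batch, d_model).shape)  # torch.Size([4, 8])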
@@ -53,7 +82,7 @@ def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
         tott += te - ts
         sqtot += (te - ts)**2
         maxt = max(maxt, te - ts)
-        backt = bte - bts
+        backt += bte - bts
     gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
                               batch_size * in_feat * num_expert) / tott
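
The gflops expression above counts two GEMMs per token copy (d_model to d_hidden and back), at two FLOPs per multiply-accumulate, plus the gate's projection onto num_expert logits. A hedged restatement (the helper name is illustrative):

def moe_gflops(batch_size, in_feat, hidden_feat, num_expert, top_k,
               n_runs, tott):
    # Two expert GEMMs per token copy: in_feat * hidden_feat MACs each.
    mlp_macs = in_feat * hidden_feat * batch_size * top_k * 2
    # Gate projection: (batch_size, in_feat) @ (in_feat, num_expert).
    gate_macs = batch_size * in_feat * num_expert
    # 2 FLOPs per MAC, 1e-9 for giga, n_runs passes over tott seconds total.
    return 2e-9 * n_runs * (mlp_macs + gate_macs) / tott

# Default config, one forward pass taking 10 ms:
print(moe_gflops(4096, 1024, 4096, 64, 2, 1, 0.01))  # ~13798 GFLOP/s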
@@ -73,5 +102,13 @@ if __name__ == '__main__':
     else:
         rank = 0
         world_size = 1
-    test_performance(4096, 1024, 4096, 8, 8)
+    batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
+    d_model = int(os.environ.get('D_MODEL', '1024'))
+    d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
+    num_expert = int(os.environ.get('NUM_EXPERT', '64'))
+    top_k = int(os.environ.get('TOP_K', '2'))
+    benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
+                  d_hidden, num_expert, top_k)
+    if world_size == 1:
+        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert,
+                      top_k)
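
With this change the benchmark entry point is configured through environment variables instead of a hard-coded call, so the same file can sweep problem sizes without edits, e.g. BATCH_SIZE=2048 TOP_K=1 python <benchmark file> (the file's path is not visible in this diff). The brute-force baseline only runs single-process, since BruteForceMoE asserts world_size == 1.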
@@ -28,14 +28,17 @@ class BruteForceMoELinear(nn.Module):
     def forward(self, inp, gate_idx, gate_score):
         gate_long = gate_idx.long()
         batch_size = inp.size(0)
-        x = inp.new_zeros((batch_size, self.d_model))
-        for i in range(batch_size):
-            t = inp[i] @ self.weight_htoh4[gate_long[i]].t()
-            t = self.activation(t)
-            x[i] = t @ self.weight_h4toh[gate_long[i]].t()
-        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
-            -1, self.d_model
-        )
+        o = torch.empty(batch_size, self.d_model, dtype=inp.dtype,
+                        device=inp.device)
+        for i in range(self.weight_htoh4.shape[0]):
+            idx = (gate_idx == i)
+            x = inp[idx]
+            x = x @ self.weight_htoh4[i].t()
+            x = self.activation(x)
+            x = x @ self.weight_h4toh[i].t()
+            o[idx] = x
+        x = torch.bmm(gate_score, o.view(-1, self.top_k,
+                                         self.d_model)).reshape(-1, self.d_model)
         return x
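
The rewritten forward trades the per-token loop for a per-expert one: all rows routed to expert i are gathered with a boolean mask, pushed through one pair of matmuls, and scattered back, so the number of GEMM launches drops from batch_size to num_expert. A minimal sketch of that gather/compute/scatter pattern (names illustrative, not the module's weights), checked against the old per-token form:

import torch

num_expert, n_rows, d_model, d_hidden = 4, 8, 16, 32
inp = torch.randn(n_rows, d_model)
gate_idx = torch.randint(num_expert, (n_rows,))
w1 = torch.randn(num_expert, d_hidden, d_model)  # h -> 4h weights
w2 = torch.randn(num_expert, d_model, d_hidden)  # 4h -> h weights

out = torch.empty_like(inp)
for e in range(num_expert):
    idx = (gate_idx == e)                 # rows routed to expert e
    h = torch.relu(inp[idx] @ w1[e].t())  # one batched GEMM per expert
    out[idx] = h @ w2[e].t()              # scatter results back in place

# Old per-token form: identical maths, but n_rows GEMM pairs instead.
ref = torch.stack([torch.relu(inp[i] @ w1[gate_idx[i]].t()) @ w2[gate_idx[i]].t()
                   for i in range(n_rows)])
print(torch.allclose(out, ref, atol=1e-5))  # True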
......
+#!/bin/bash
+
+runtest() {
+    echo Testing "$@"
+    "$@"
+    if [ $? = 0 ]
+    then
+        echo '----------------- Passed'
+    else
+        echo '----------------- Failed'
+        exit 1
+    fi
+}
+
+if [ ! -z "$1" ]
+then
+    runtest "$@"
+    exit
+fi
+
+TEST_SCRIPT=$(dirname $(realpath $0))/test.sh
+runtest $TEST_SCRIPT tests/test_numerical.py
+runtest mpirun -n 2 $TEST_SCRIPT tests/test_numerical.py
+runtest $TEST_SCRIPT tests/test_dp.py
+runtest $TEST_SCRIPT tests/test_performance.py
+runtest mpirun -n 2 $TEST_SCRIPT tests/test_performance.py
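
Assuming this new runner is saved as, say, run_tests.sh next to test.sh (the latter is named in the script; the former is illustrative), invoking it with no arguments executes the whole suite and stops at the first failure, while passing a command, e.g. run_tests.sh tests/test_dp.py, wraps that single command in runtest.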