Commit 101b847c authored by Rick Ho

adapt benchmark to new moe module

parent 8dac1a52
@@ -20,9 +20,8 @@ class BruteForceMoE(nn.Module):
                  gate=NaiveGate, top_k=1, pre_lnorm=False):
         assert world_size == 1, 'Distributed brute force is not supported'
         super().__init__()
-        self.mlp1 = BruteForceMoELinear(num_expert, d_model, d_hidden, 1)
-        self.mlp2 = BruteForceMoELinear(num_expert, d_hidden, d_model, 1)
-        self.activation = activation
+        self.mlp = BruteForceMoELinear(activation, num_expert, d_model,
+                d_hidden, 1, top_k)
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         self.pre_lnorm = pre_lnorm
@@ -34,11 +33,7 @@ class BruteForceMoE(nn.Module):
             inp = self.layer_norm(inp)
         gate_top_k_idx, gate_score = self.gate(inp)
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = self.mlp1(inp, gate_top_k_idx)
-        x = self.activation(x)
-        x = self.mlp2(x, gate_top_k_idx)
-        x = x.view(-1, self.top_k, self.d_model)
-        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
+        x = self.mlp(inp, gate_top_k_idx, gate_score)
         if not self.pre_lnorm:
             x = self.layer_norm(x)
         return x
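
A minimal sketch (not the repository's implementation) of the computation the fused self.mlp(inp, gate_top_k_idx, gate_score) call now covers, pieced together from the removed lines above: two per-expert linear layers with an activation in between, followed by a gate-score weighted combination of the top-k expert outputs. The function name, weight layout, and tensor shapes below are assumptions for illustration, not the actual BruteForceMoELinear interface.

import torch

def fused_expert_mlp(inp, gate_top_k_idx, gate_score, w1, w2, activation, top_k):
    # Assumed shapes: inp is (batch * top_k, d_model) after repeat_interleave,
    # gate_top_k_idx is (batch * top_k,), gate_score is (batch, 1, top_k),
    # w1 is (num_expert, d_hidden, d_model), w2 is (num_expert, d_model, d_hidden).
    d_model = w2.shape[1]
    out = torch.empty_like(inp)
    for e in range(w1.shape[0]):
        mask = gate_top_k_idx == e                  # tokens routed to expert e
        if mask.any():
            h = activation(inp[mask] @ w1[e].t())   # d_model -> d_hidden
            out[mask] = h @ w2[e].t()               # d_hidden -> d_model
    out = out.view(-1, top_k, d_model)              # (batch, top_k, d_model)
    return torch.bmm(gate_score, out).reshape(-1, d_model)
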
@@ -47,7 +42,6 @@ class BruteForceMoE(nn.Module):
 def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
         print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'
               .format(MOELayer.__name__, batch_size, in_feat, hidden_feat,
@@ -108,11 +102,10 @@ if __name__ == '__main__':
     else:
         rank = 0
         world_size = 1
     batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
     d_model = int(os.environ.get('D_MODEL', '1024'))
     d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
-    num_expert = int(os.environ.get('NUM_EXPERT', '8'))
+    num_expert = int(os.environ.get('NUM_EXPERT', '64'))
     top_k = int(os.environ.get('TOP_K', '2'))
     benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
             d_hidden, num_expert, top_k)
......
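
A back-of-the-envelope sizing of the default configuration above (BATCH_SIZE=4096, D_MODEL=1024, D_HIDDEN=4096, TOP_K=2), just to make the printed "mm size" concrete; this is a rough sketch, not something the benchmark itself computes.

batch_size, d_model, d_hidden, top_k = 4096, 1024, 4096, 2
tokens = batch_size * top_k                       # tokens after top-k expansion
flops = 2 * tokens * (d_model * d_hidden + d_hidden * d_model)
print(f'{tokens} expert-token pairs, roughly {flops / 1e9:.0f} GFLOPs per forward pass')
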
@@ -40,7 +40,7 @@ class BruteForceMoELinear(nn.Module):
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
             self.d_model)).reshape(-1, self.d_model)
         return x
 class BruteForceMoE(nn.Module):
     def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
......
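
A quick shape check for the gate_score combination in the hunk above; the (batch, 1, top_k) layout of gate_score is inferred from the torch.bmm call and is an assumption, not a documented interface.

import torch

batch, top_k, d_model = 4, 2, 8
o = torch.randn(batch * top_k, d_model)                    # stacked top-k expert outputs
gate_score = torch.softmax(torch.randn(batch, 1, top_k), dim=-1)
x = torch.bmm(gate_score, o.view(-1, top_k, d_model)).reshape(-1, d_model)
assert x.shape == (batch, d_model)                         # one combined vector per token
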
#!/bin/bash
# Run a single test command and report pass / fail; abort the suite on failure.
runtest() {
    echo Testing "$@"
    "$@"
    if [ $? -eq 0 ]
    then
        echo '----------------- Passed'
    else
        echo '----------------- Failed'
        exit 1
    fi
}

# If a test is given on the command line, run only that test.
if [ -n "$1" ]
then
    runtest "$@"
    exit
fi

# Otherwise run the full suite, including the MPI variants.
TEST_SCRIPT=$(dirname $(realpath $0))/test.sh
runtest $TEST_SCRIPT tests/test_numerical.py
runtest mpirun -n 2 $TEST_SCRIPT tests/test_numerical.py
runtest $TEST_SCRIPT tests/test_dp.py
runtest $TEST_SCRIPT tests/test_performance.py
runtest mpirun -n 2 $TEST_SCRIPT tests/test_performance.py