Commit c5556037 authored by Rick Ho

newer test with top-k bug fixed

parent 103343ca
@@ -77,7 +77,6 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     return x
 
-
 class FMoE(nn.Module):
     r'''
     A general moe implementation that supports an arbitrary module as the expert
@@ -50,7 +50,8 @@ class FMoETransformerMLP(FMoE):
         def expert_fn(inp, gate):
             return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                world_size=world_size, mp_group=mp_group, expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group,
+                expert_fn=expert_fn)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
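With top_k now forwarded to the FMoE constructor, the transformer MLP routes tokens itself and its forward pass takes only the input tensor. A minimal usage sketch of the updated call pattern, mirroring the keyword arguments used by the benchmark below (assumes fmoe is built and a CUDA device is available; shapes are illustrative):

    import torch
    from fmoe import FMoETransformerMLP

    # 4 experts on this worker; each token is dispatched to its top-2 experts
    # and the expert outputs are combined by the layer's internal gate.
    moe = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096,
                             world_size=1, top_k=2).cuda()

    inp = torch.rand(512, 1024, device='cuda', requires_grad=True)
    out = moe(inp)        # no gate argument: gating happens inside the layer
    out.sum().backward()  # gradients flow back through the selected experts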
-from moe import FMoE as MOELayer
+from fmoe import FMoETransformerMLP as MOELayer
 import torch
 import time
 import sys
@@ -10,29 +10,29 @@ world_size = None
 dev_name_default = 'cuda:0'
 
-def test_performance(batch_size, in_feat, out_feat, num_expert):
+def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
-        print('Performance test case bs {} {}x{} ne {}x{}'.format(
-            batch_size, in_feat, out_feat, world_size, num_expert))
+        print('Performance test case bs {} {}x{} ne {}x{} topk {}'.format(
+            batch_size, in_feat, hidden_feat, world_size, num_expert, top_k))
     if world_size > 1:
         dev_name = 'cuda'
     else:
         dev_name = dev_name_default
     inp = torch.rand(batch_size, in_feat).cuda(dev_name)
-    gate = torch.randint(low=0,
-            high=num_expert * world_size,
-            size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
     inp.requires_grad = True
-    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
+    moe = MOELayer(num_expert=num_expert,
+            d_model=in_feat, d_hidden=hidden_feat,
+            world_size=world_size, top_k=top_k).cuda(dev_name)
     moe.train()
 
     # warm up
     for _ in range(4):
-        _ = moe(inp, gate)
+        _ = moe(inp)
 
     n_runs = 16
     tott = 0.
@@ -40,11 +40,8 @@ def test_performance(batch_size, in_feat, out_feat, num_expert):
     maxt = 0.
     sqtot = 0.
     for i in range(n_runs):
-        gate = torch.randint(low=0,
-                high=num_expert * world_size,
-                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
         ts = time.time()
-        o = moe(inp, gate)
+        o = moe(inp)
         te = time.time()
         loss = o.sum()
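One caveat about the timing loop above: CUDA kernels launch asynchronously, so a time.time() taken right after moe(inp) can return before the GPU has finished the forward pass. A sketch of a synchronized variant (the timed_forward helper is hypothetical, not part of this commit):

    import time
    import torch

    def timed_forward(moe, inp):
        # Synchronize before and after so the wall-clock interval brackets
        # the actual GPU work, not just the asynchronous kernel launches.
        torch.cuda.synchronize()
        ts = time.time()
        out = moe(inp)
        torch.cuda.synchronize()
        return out, time.time() - ts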
@@ -58,10 +55,11 @@ def test_performance(batch_size, in_feat, out_feat, num_expert):
         maxt = max(maxt, te - ts)
         backt = bte - bts
 
-    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
+    gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
+            batch_size * in_feat * num_expert) / tott
     print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
         tott * 1e3 / n_runs, maxt * 1e3,
-        (sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs,
+        (sqtot / n_runs - (tott / n_runs)**2) * 1e3 * top_k / n_runs,
         backt * 1e3 / n_runs, gflops))
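The updated formula counts, per token, the two GEMMs of each selected expert's MLP (in_feat x hidden_feat and hidden_feat x in_feat, hence the factor of 2 inside the parentheses) plus the in_feat x num_expert gate projection, at 2 FLOPs per multiply-accumulate (the leading 2e-9, which also converts to GFLOPs). A quick sanity check of the count for the default test case; tott is a placeholder, not a measurement:

    batch_size, in_feat, hidden_feat = 4096, 1024, 4096
    num_expert, top_k, n_runs = 8, 8, 16

    # Two in_feat x hidden_feat GEMMs per token per selected expert, plus one
    # in_feat x num_expert gate projection per token; 2 FLOPs per MAC.
    flops_per_run = 2 * (in_feat * hidden_feat * batch_size * top_k * 2
                         + batch_size * in_feat * num_expert)
    tott = 1.0  # total time in seconds over n_runs -- placeholder value
    print('{:.1f} GFLOPs'.format(flops_per_run * n_runs * 1e-9 / tott))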
@@ -76,4 +74,4 @@ if __name__ == '__main__':
         rank = 0
         world_size = 1
-    test_performance(4096, 1024, 4096, 8)
+    test_performance(4096, 1024, 4096, 8, 8)