"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "893098888e6a0214a15d8c59f5719444a31bc0d0"
Commit c5556037 authored by Rick Ho

newer test with top-k bug fixed

parent 103343ca
@@ -77,7 +77,6 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     return x
 
-
 class FMoE(nn.Module):
     r'''
     A general moe implementation that supports an arbitrary module as the expert
...
@@ -50,7 +50,8 @@ class FMoETransformerMLP(FMoE):
         def expert_fn(inp, gate):
            return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                world_size=world_size, mp_group=mp_group, expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group,
+                expert_fn=expert_fn)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
...
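The change above forwards top_k from FMoETransformerMLP to the FMoE base class instead of dropping it, which appears to be the top-k bug named in the commit message. A minimal sketch of constructing the layer after this fix (argument names are taken from the diff; shapes and defaults are assumptions):

import torch
from fmoe import FMoETransformerMLP

# top_k now reaches the gate instead of being silently ignored
mlp = FMoETransformerMLP(num_expert=8, d_model=1024, d_hidden=4096,
                         world_size=1, top_k=2).cuda()
x = torch.rand(16, 1024, requires_grad=True, device='cuda')
y = mlp(x)          # the gate is computed inside the layer
y.sum().backward()  # gradients flow back through the routed experts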
@@ -1,4 +1,4 @@
-from moe import FMoE as MOELayer
+from fmoe import FMoETransformerMLP as MOELayer
 import torch
 import time
 import sys
@@ -10,29 +10,29 @@ world_size = None
 dev_name_default = 'cuda:0'
 
-def test_performance(batch_size, in_feat, out_feat, num_expert):
+def test_performance(batch_size, in_feat, hidden_feat, num_expert, top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
-        print('Performance test case bs {} {}x{} ne {}x{}'.format(
-            batch_size, in_feat, out_feat, world_size, num_expert))
+        print('Performance test case bs {} {}x{} ne {}x{} topk {}'.format(
+            batch_size, in_feat, hidden_feat, world_size, num_expert, top_k))
     if world_size > 1:
         dev_name = 'cuda'
     else:
         dev_name = dev_name_default
     inp = torch.rand(batch_size, in_feat).cuda(dev_name)
-    gate = torch.randint(low=0,
-            high=num_expert * world_size,
-            size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
-    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
+    inp.requires_grad = True
+    moe = MOELayer(num_expert=num_expert,
+            d_model=in_feat, d_hidden=hidden_feat,
+            world_size=world_size, top_k=top_k).cuda(dev_name)
     moe.train()
 
     # warm up
     for _ in range(4):
-        _ = moe(inp, gate)
+        _ = moe(inp)
 
     n_runs = 16
     tott = 0.
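After this change the benchmark no longer fabricates a random gate tensor; expert assignment is decided inside the model. For illustration only, a naive top-k softmax gate could look like the sketch below (an assumed outline, not fastmoe's actual gate code):

import torch

def naive_top_k_gate(inp, gate_weight, top_k):
    # one score per (token, expert) pair: (batch, num_expert)
    score = inp @ gate_weight
    # keep each token's top_k experts and renormalize their weights
    val, idx = torch.topk(score, k=top_k, dim=-1)
    weight = torch.softmax(val, dim=-1)
    return idx, weight  # expert indices and combine weights per token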
@@ -40,11 +40,8 @@ def test_performance(batch_size, in_feat, out_feat, num_expert):
     maxt = 0.
     sqtot = 0.
     for i in range(n_runs):
-        gate = torch.randint(low=0,
-                high=num_expert * world_size,
-                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
         ts = time.time()
-        o = moe(inp, gate)
+        o = moe(inp)
         te = time.time()
         loss = o.sum()
...@@ -58,10 +55,11 @@ def test_performance(batch_size, in_feat, out_feat, num_expert): ...@@ -58,10 +55,11 @@ def test_performance(batch_size, in_feat, out_feat, num_expert):
maxt = max(maxt, te - ts) maxt = max(maxt, te - ts)
backt = bte - bts backt = bte - bts
gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
batch_size * in_feat * num_expert) / tott
print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format( print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3, tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, (sqtot / n_runs - (tott / n_runs)**2) * 1e3 * top_k / n_runs,
backt * 1e3 / n_runs, gflops)) backt * 1e3 / n_runs, gflops))
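The new throughput estimate counts the two expert GEMMs (d_model -> d_hidden -> d_model) for each of the top_k copies of every token, plus the gate projection, at 2 FLOPs per multiply-add (folded into the leading 2e-9 along with the conversion to GFLOPs). A quick sanity check on the default configuration (not part of the repository):

batch_size, in_feat, hidden_feat, num_expert, top_k = 4096, 1024, 4096, 8, 8
# two GEMMs per routed token, each token replicated across top_k experts
expert_macs = in_feat * hidden_feat * batch_size * top_k * 2
# one score per (token, expert) pair for the gate
gate_macs = batch_size * in_feat * num_expert
print(2e-9 * (expert_macs + gate_macs))  # ~549.8 GFLOP per forward pass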
@@ -76,4 +74,4 @@ if __name__ == '__main__':
         rank = 0
         world_size = 1
 
-    test_performance(4096, 1024, 4096, 8)
+    test_performance(4096, 1024, 4096, 8, 8)