"examples/offline_inference/basic/classify.py" did not exist on "6f676d33f5c04974f87d64060ff5df0e963bc517"
moe_test.py 1.27 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
from moe import MOELayer
import torch
import time
Rick Ho's avatar
Rick Ho committed
4
import sys
Rick Ho's avatar
Rick Ho committed
5
6


Rick Ho's avatar
Rick Ho committed
7
8
9
# CUDA device on which the benchmark input, the gate tensor and the MoE layer
# are all placed.
dev_name: str = 'cuda:0'


Rick Ho's avatar
Rick Ho committed
10
def perf(n_runs=16, n_warmup=6):
    """Benchmark the forward pass of ``MOELayer`` on ``dev_name``.

    Problem sizes are read from the command line as
    ``batch_size in_feat out_feat num_expert`` (``sys.argv[1:5]``).
    Prints the mean / max / standard deviation of the per-call wall time in
    milliseconds and a throughput figure that counts
    ``2 * in_feat * out_feat`` FLOPs per input row (one multiply-add per
    weight of the routed expert).

    Args:
        n_runs: number of timed forward calls (must be >= 1).
        n_warmup: number of untimed forward calls issued before timing.

    Raises:
        SystemExit: if fewer than four size arguments are given on the
            command line.
        ValueError: if ``n_runs`` is smaller than 1.
    """
    if len(sys.argv) < 5:
        raise SystemExit(
            'usage: {} batch_size in_feat out_feat num_expert'.format(
                sys.argv[0]))
    if n_runs < 1:
        raise ValueError('n_runs must be >= 1, got {}'.format(n_runs))

    batch_size = int(sys.argv[1])
    in_feat = int(sys.argv[2])
    out_feat = int(sys.argv[3])
    num_expert = int(sys.argv[4])

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)

    def random_gate():
        # One expert index in [0, num_expert) per input row, as int32 on the
        # benchmark device.
        return torch.randint(low=0, high=num_expert, size=(batch_size, ),
                             requires_grad=False).int().cuda(dev_name)

    # Untimed warm-up calls so one-off start-up costs stay out of the stats.
    gate = random_gate()
    for _ in range(n_warmup):
        moe(inp, gate)

    times = []
    for _ in range(n_runs):
        # Fresh routing each run, generated outside the timed region.
        gate = random_gate()
        # CUDA work is launched asynchronously: drain pending work before
        # starting the clock, and wait for this call's kernels to finish
        # before stopping it, so the interval covers the actual computation.
        torch.cuda.synchronize(dev_name)
        ts = time.perf_counter()
        moe(inp, gate)
        torch.cuda.synchronize(dev_name)
        times.append(time.perf_counter() - ts)

    total = sum(times)
    mean = total / n_runs
    # Population standard deviation of the per-call times (seconds).
    stdev = (sum((t - mean) ** 2 for t in times) / n_runs) ** 0.5
    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / total
    print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        mean * 1e3, max(times) * 1e3, stdev * 1e3, gflops))
Rick Ho's avatar
Rick Ho committed
48
49
50
51


if '__main__' == __name__:
    # Run the benchmark only when executed as a script, not on import.
    perf()