moe_test.py 1.81 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
from moe import MOELayer, MOELayer_raw
Rick Ho's avatar
Rick Ho committed
2
3
import torch
import time
Rick Ho's avatar
Rick Ho committed
4
import sys
Rick Ho's avatar
Rick Ho committed
5
6


Rick Ho's avatar
Rick Ho committed
7
8
9
# Device string passed to every ``.cuda(...)`` call in ``perf`` (input tensor,
# gate tensor and the MoE layer itself).
dev_name = 'cuda:0'


Rick Ho's avatar
Rick Ho committed
10
def perf():
    """Benchmark forward and backward passes of ``MOELayer`` on this rank.

    Four positional command-line arguments are read from ``sys.argv``:
    ``batch_size``, ``in_feat``, ``out_feat`` and ``num_expert`` (all ints).
    ``torch.distributed`` must already be initialised, because the rank is
    used to seed the RNGs and the world size bounds the gate indices.

    The layer is run four times untimed as a warm-up, then ``n_runs`` times
    with a fresh random gate each iteration.  One summary line is printed:
    mean / max / standard deviation of the forward time and mean backward
    time (all in ms), plus a GFLOPs figure that counts
    ``2 * in_feat * out_feat * batch_size`` operations per forward pass.

    Raises:
        SystemExit: if fewer than four arguments were supplied.
        ValueError: if an argument cannot be parsed as an integer.
    """
    if len(sys.argv) < 5:
        raise SystemExit(
            'usage: {} batch_size in_feat out_feat num_expert'.format(
                sys.argv[0]))
    batch_size, in_feat, out_feat, num_expert = (
        int(arg) for arg in sys.argv[1:5])

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    device = torch.device(dev_name)

    # Per-rank seeds so that every process draws different inputs and gates.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    def random_gate():
        # One expert index per sample, drawn from [0, num_expert * world_size).
        return torch.randint(
            low=0,
            high=num_expert * world_size,
            size=(batch_size, ),
            requires_grad=False).int().cuda(dev_name)

    # The original referenced an undefined ``io_feat`` here; the input width
    # is the layer's ``in_feat``.
    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = random_gate()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
    moe.train()

    # Untimed warm-up passes.
    for _ in range(4):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for _ in range(n_runs):
        gate = random_gate()

        # CUDA kernels are launched asynchronously: synchronise around each
        # timed region so the clock measures completed work, not launches.
        torch.cuda.synchronize(device)
        ts = time.perf_counter()
        o = moe(inp, gate)
        torch.cuda.synchronize(device)
        te = time.perf_counter()

        loss = o.sum()
        torch.cuda.synchronize(device)

        bts = time.perf_counter()
        loss.backward()
        torch.cuda.synchronize(device)
        bte = time.perf_counter()

        fwd = te - ts
        tott += fwd
        sqtot += fwd ** 2
        maxt = max(maxt, fwd)
        # Accumulate (the original overwrote), since the report divides the
        # total by n_runs to obtain the mean backward time.
        backt += bte - bts

    mean = tott / n_runs
    # Population variance via E[x^2] - E[x]^2; clamp tiny negative values
    # caused by floating-point rounding before taking the square root.
    variance = max(sqtot / n_runs - mean ** 2, 0.)
    stdev = variance ** 0.5

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        mean * 1e3, maxt * 1e3,
        stdev * 1e3,
        backt * 1e3 / n_runs, gflops))
Rick Ho's avatar
Rick Ho committed
62
63
64


if __name__ == '__main__':
    # The default process group has to exist before perf() runs, because
    # perf() queries the rank and world size through torch.distributed.
    torch.distributed.init_process_group('mpi')
    perf()