moe_test.py
from moe import MOELayer, MOELayer_raw
import torch
from torch import nn
import time
import sys


dev_name_default = 'cuda:0'


def perf():
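    """Benchmark MOELayer forward/backward latency and forward GFLOP/s."""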
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())
    world_size = torch.distributed.get_world_size()
    
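    # Optional CLI override (dispatched from __main__ as
    # `moe_test.py performance <batch_size> <in_feat> <out_feat> <num_expert>`).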
    if len(sys.argv) == 6:
        batch_size = int(sys.argv[2])
        in_feat = int(sys.argv[3])
        out_feat = int(sys.argv[4])
        num_expert = int(sys.argv[5])
    else:
        batch_size = 4096
        in_feat = 1024
        out_feat = 4096
        num_expert = 4
    if torch.distributed.get_rank() == 0:
        print('Performance test case bs {} {}x{} ne {}'.format(batch_size,
            in_feat, out_feat, num_expert))
    if torch.distributed.get_world_size() > 1:
        dev_name = 'cuda'
    else:
        dev_name = dev_name_default

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = torch.randint(low=0, 
            high=num_expert * torch.distributed.get_world_size(), 
            size=(batch_size, ), requires_grad=False).int().cuda(dev_name)

    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
    moe.train()

    # Warm-up iterations so that one-off costs (CUDA context creation,
    # memory allocation) do not pollute the timed runs.
    for _ in range(4):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for i in range(n_runs):
        gate = torch.randint(low=0, 
                high=num_expert * torch.distributed.get_world_size(), 
                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
        # Forward timing; kernels launch asynchronously, so synchronize
        # before reading the clock on either side.
        torch.cuda.synchronize()
        ts = time.time()
        o = moe(inp, gate)
        torch.cuda.synchronize()
        te = time.time()

        loss = o.sum()

        bts = time.time()
        loss.backward()
        torch.cuda.synchronize()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts)**2
        maxt = max(maxt, te - ts)
        backt += bte - bts

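    # Forward throughput: each of the batch_size tokens costs roughly
    # 2 * in_feat * out_feat FLOPs per run.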
    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        tott * 1e3 / n_runs, maxt * 1e3,
        (sqtot / n_runs - (tott / n_runs) ** 2) ** .5 * 1e3,
        backt * 1e3 / n_runs, gflops))


def test_module(moe, linear, inp, gate):
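    """Run one forward/backward pass through linear -> moe and return the
    output along with the MoE weight, linear weight and linear bias gradients."""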
    linear.zero_grad()
    moe.zero_grad()
    x = (linear(inp))
    output = moe(x, gate)
    # print('ooutput', torch.distributed.get_rank(), output)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


def test():
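    """Correctness check: compare MOELayer against the MOELayer_raw reference
    on identical inputs and report the absolute error of outputs and gradients."""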
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())
    world_size = torch.distributed.get_world_size()
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    if world_size > 1:
        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    else:
        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
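    # The reference layer keeps a full copy of every expert's weight, so in the
    # distributed case gather each rank's expert shard into it first.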
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        weight_array = [torch.empty_like(moe.weight.data).cpu() 
                for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, 
            high=num_expert * world_size, 
            size=(batch_size,), 
            requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
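    # moe_raw sees only the local batch but owns every expert, so sum its weight
    # gradients across ranks and keep this rank's slice before comparing with
    # the sharded MOELayer gradient.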
    if world_size > 1:
        rank = torch.distributed.get_rank()
        ou, wg, lwg, lbg = raw_out
        wg = wg.cpu()
        torch.distributed.all_reduce(wg)
        wg = wg[rank * num_expert:(rank + 1)* num_expert]
        raw_out = ou, wg.cuda(), lwg, lbg
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))


def test_dp():
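    """Smoke test: wrap MOELayer in torch.nn.DataParallel and run the forward
    pass a few times (assumes GPUs 0-2 are visible)."""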
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0,1,2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0,1,2])
    for i in range(5):
        output = moe_dp(inp, gate)


if __name__ == '__main__':
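    # Requires a PyTorch build with MPI support; presumably launched with
    # mpirun so that each rank drives one GPU.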
    torch.distributed.init_process_group(backend='mpi')
    world_size = torch.distributed.get_world_size()
    if len(sys.argv) >= 2:
        task = sys.argv[1]
        print('Specified task {}'.format(task))
        if task == 'correctness':
            test()
        elif task == 'dp':
            test_dp()
        elif task == 'performance':
            perf()
    else:
        test()