moe_test.py 5.52 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
from moe import FMoE as MOELayer 
from moe import BruteForceMoE as MOELayer_raw
Rick Ho's avatar
Rick Ho committed
3
4
5
6
import torch
from torch import nn
import time
import sys
7
import os
Rick Ho's avatar
Rick Ho committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89


# Device that perf() places its tensors and layer on when the world size is
# not greater than one.
dev_name_default = 'cuda:0'


def perf():
    """Benchmark forward and backward passes of the MoE layer.

    The problem size comes from the command line when exactly six argv
    entries are present (``argv[2:6]`` = batch_size, in_feat, out_feat,
    num_expert); otherwise 4096 / 1024 / 4096 / 4 is used.  After four
    warm-up forward passes, 16 timed iterations are run with a fresh random
    gate each time, and the mean / max / standard deviation of the forward
    time, the mean backward time (all in ms) and the forward GFLOPs are
    printed.
    """
    # get_rank()/get_world_size() raise unless a process group has been
    # initialised, so fall back to a single-process view when none exists.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        cur_rank = torch.distributed.get_rank()
        n_workers = torch.distributed.get_world_size()
    else:
        cur_rank, n_workers = 0, 1

    torch.manual_seed(42 + cur_rank)
    torch.cuda.manual_seed(42 + cur_rank)

    if len(sys.argv) == 6:
        batch_size = int(sys.argv[2])
        in_feat = int(sys.argv[3])
        out_feat = int(sys.argv[4])
        num_expert = int(sys.argv[5])
    else:
        batch_size = 4096
        in_feat = 1024
        out_feat = 4096
        num_expert = 4
    if cur_rank == 0:
        print('Performance test case bs {} {}x{} ne {}'.format(batch_size,
            in_feat, out_feat, num_expert))
    if n_workers > 1:
        dev_name = 'cuda'
    else:
        dev_name = dev_name_default

    def random_gate():
        # One expert index per sample, drawn over the experts of all workers.
        return torch.randint(low=0,
                high=num_expert * n_workers,
                size=(batch_size, ), requires_grad=False).int().cuda(dev_name)

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = random_gate()

    moe = MOELayer(num_expert, in_feat, out_feat, n_workers).cuda(dev_name)
    moe.train()

    # Warm-up passes, excluded from the statistics.
    for _ in range(4):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for _ in range(n_runs):
        gate = random_gate()

        # CUDA kernels are launched asynchronously; synchronise around each
        # timed region so the wall clock covers the actual GPU work.
        torch.cuda.synchronize(inp.device)
        ts = time.perf_counter()
        o = moe(inp, gate)
        torch.cuda.synchronize(inp.device)
        te = time.perf_counter()

        loss = o.sum()

        torch.cuda.synchronize(inp.device)
        bts = time.perf_counter()
        loss.backward()
        torch.cuda.synchronize(inp.device)
        bte = time.perf_counter()

        fwd = te - ts
        tott += fwd
        sqtot += fwd ** 2
        maxt = max(maxt, fwd)
        # Accumulate: the report divides by n_runs to obtain a mean.
        backt += bte - bts

    mean_fwd = tott / n_runs
    # Clamp at zero so rounding error cannot produce a negative variance.
    variance = max(sqtot / n_runs - mean_fwd ** 2, 0.)
    stdev = variance ** 0.5

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        mean_fwd * 1e3, maxt * 1e3,
        stdev * 1e3,
        backt * 1e3 / n_runs, gflops))


def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = (linear(inp))
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


90
91
92
93
# Process rank and total number of processes.  Both are placeholders here and
# are assigned in the __main__ block before any task function is called.
rank = None
world_size = None


Rick Ho's avatar
Rick Ho committed
94
def test():
    """Compare ``MOELayer`` with ``MOELayer_raw`` on a small random problem.

    Both layers are given the same weights (gathered from every rank when
    ``world_size > 1``), then run through ``test_module`` with identical
    input and gate.  For the output and each returned gradient, the summed
    absolute difference is printed; on the first difference above 1e-3 both
    tensors are written to stderr and the function returns.

    Reads the module-level ``rank`` and ``world_size``.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        # Collect every rank's weight and stack them along dim 0 for the
        # raw layer.
        weight_array = [torch.empty_like(moe.weight.data)
                for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data)
        moe_raw.weight.data = torch.cat(weight_array, dim=0)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
            high=num_expert * world_size,
            size=(batch_size,),
            requires_grad=False).int().cuda()
    # For a fixed routing while debugging:
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    if world_size > 1:
        ou, wg, lwg, lbg = raw_out
        # Sum the raw layer's weight gradient over all ranks, then keep the
        # rows belonging to this rank's experts so it lines up with ``moe``.
        torch.distributed.all_reduce(wg)
        wg = wg[rank * num_expert:(rank + 1) * num_expert]
        raw_out = ou, wg, lwg, lbg

    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('Rank {} {} abs err {}'.format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write('=========== moe out ==============\n')
            sys.stderr.write('{}\n'.format(mo))
            sys.stderr.write('=========== raw out ==============\n')
            sys.stderr.write('{}\n'.format(ro))
            return
Rick Ho's avatar
Rick Ho committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165


def test_dp():
    """Smoke-test ``torch.nn.DataParallel`` wrapping of the MoE layer.

    First wraps a plain ``nn.Linear`` in ``DataParallel`` over devices
    0, 1 and 2 and runs it once, then does the same with ``MOELayer`` and
    runs five forward passes.  Nothing is asserted; the check is that the
    calls complete.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size, num_expert = 6, 4
    in_feat, out_feat = 2, 3
    gpu_ids = [0, 1, 2]

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(
        low=0, high=num_expert, size=(batch_size, ), requires_grad=False)
    gate = gate.int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=gpu_ids)
    linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=gpu_ids)
    for _ in range(5):
        moe_dp(inp, gate)


if __name__ == '__main__':
    # Copy the OMPI_COMM_WORLD_* variables into RANK / WORLD_SIZE, which the
    # default env:// initialisation of torch.distributed reads.  Without
    # them, fall back to a single process.
    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
    if int(os.environ['WORLD_SIZE']) > 1:
        torch.distributed.init_process_group(backend='nccl')
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        # No process group is created for a single process.
        rank = 0
        world_size = 1

    # argv[1] selects the task; with no argument the correctness test runs.
    if len(sys.argv) >= 2:
        task = sys.argv[1]
        print('Specificed task {}'.format(task))
        if task == 'correctness':
            test()
        elif task == 'dp':
            test_dp()
        elif task == 'performance':
            perf()
        else:
            # Report an unrecognised task instead of silently doing nothing.
            sys.stderr.write('Unknown task {}\n'.format(task))
    else:
        test()