# moe_test.py
import math
import sys
import time

import torch
from torch import nn

from moe import MOELayer, MOELayer_raw
# CUDA device that perf() moves its tensors and the benchmarked layer onto.
dev_name = 'cuda:1'


def perf():
    """Benchmark the forward and backward pass of ``MOELayer``.

    Reads four positional command-line arguments, in order: batch size,
    input feature size, output feature size and number of experts.
    ``torch.distributed`` must already be initialised, because the random
    seed depends on the rank and the gate range on the world size.

    Prints the mean, maximum and standard deviation of the forward time,
    the mean backward time (all in milliseconds), and the forward
    throughput in GFLOPs.
    """
    rank = torch.distributed.get_rank()
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    batch_size = int(sys.argv[1])
    in_feat = int(sys.argv[2])
    out_feat = int(sys.argv[3])
    num_expert = int(sys.argv[4])

    device = torch.device(dev_name)

    def random_gate():
        """Draw one integer in [0, num_expert * world_size) per sample."""
        return torch.randint(
            low=0,
            high=num_expert * torch.distributed.get_world_size(),
            size=(batch_size,),
            requires_grad=False).int().cuda(dev_name)

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = random_gate()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
    moe.train()

    # Untimed warm-up passes so one-off start-up costs stay out of the
    # measurements.
    for _ in range(4):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for _ in range(n_runs):
        gate = random_gate()

        # CUDA work is queued asynchronously; synchronise around each timed
        # region so the wall clock covers execution, not just the launch.
        torch.cuda.synchronize(device)
        ts = time.perf_counter()
        o = moe(inp, gate)
        torch.cuda.synchronize(device)
        te = time.perf_counter()

        loss = o.sum()

        torch.cuda.synchronize(device)
        bts = time.perf_counter()
        loss.backward()
        torch.cuda.synchronize(device)
        bte = time.perf_counter()

        fwd = te - ts
        tott += fwd
        sqtot += fwd ** 2
        maxt = max(maxt, fwd)
        # Accumulate: the report below divides by n_runs to get a mean.
        backt += bte - bts

    mean_fwd = tott / n_runs
    # Population variance E[t^2] - E[t]^2, clamped at zero because floating
    # point rounding can push it slightly negative.
    variance = max(sqtot / n_runs - mean_fwd ** 2, 0.)
    stdev = math.sqrt(variance)

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        mean_fwd * 1e3, maxt * 1e3,
        stdev * 1e3,
        backt * 1e3 / n_runs, gflops))


def test_module(moe, linear, inp, gate):
    """Run one forward/backward pass through ``linear`` followed by ``moe``.

    Both modules have their gradients reset first. The mean of the output is
    used as the scalar that is back-propagated.

    Args:
        moe: Module called as ``moe(x, gate)``; must expose ``weight``.
        linear: Module applied to ``inp`` before ``moe``; must expose
            ``weight`` and ``bias``.
        inp: Input passed to ``linear``.
        gate: Second argument passed to ``moe``.

    Returns:
        Tuple ``(output, moe.weight.grad, linear.weight.grad,
        linear.bias.grad)``. The gradients are copies, so they stay valid
        after a later call resets or re-accumulates the module gradients.
    """
    linear.zero_grad()
    moe.zero_grad()
    output = moe(linear(inp), gate)
    print('ooutput', torch.distributed.get_rank(), output)
    output.mean().backward()

    def snapshot(grad):
        # ``zero_grad`` may zero the gradient tensor in place; without a copy,
        # results from two calls sharing ``linear`` would alias one tensor and
        # always compare equal.
        return None if grad is None else grad.clone()

    return (output,
            snapshot(moe.weight.grad),
            snapshot(linear.weight.grad),
            snapshot(linear.bias.grad))


def test():
    """Compare ``MOELayer`` with the reference ``MOELayer_raw``.

    Builds both layers with matching weights, feeds them the same random
    input and gate through a shared ``nn.Linear``, and prints the summed
    absolute difference of each compared result. With a single worker the
    output and all gradients are compared; with several workers only the
    output is.

    Requires CUDA and an initialised ``torch.distributed`` process group.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    # Ask the process group directly instead of relying on a module-level
    # global that only exists when this file is run as a script.
    world_size = torch.distributed.get_world_size()

    linear = nn.Linear(in_feat, in_feat).cuda()

    if world_size > 1:
        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    else:
        moe = MOELayer(num_expert, in_feat, out_feat).cuda()

    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        # The reference layer gets every worker's weights, concatenated
        # along dim 0 in rank order. The gather goes through CPU tensors.
        weight_array = [torch.empty_like(moe.weight.data).cpu()
                        for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * world_size,
                         size=(batch_size,),
                         requires_grad=False).int().cuda()
    # For a fixed gate while debugging:
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    if world_size == 1:
        names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    else:
        # zip() stops at the shortest sequence, so only the outputs are
        # compared in the multi-worker case.
        names = ['Out']

    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))


def test_dp():
    """Smoke-test ``torch.nn.DataParallel`` on ``nn.Linear`` and ``MOELayer``.

    Wraps each module for devices 0, 1 and 2 and runs forward passes on a
    small random batch. Nothing is checked beyond the calls completing.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size, num_expert = 6, 4
    in_feat, out_feat = 2, 3
    gpu_ids = [0, 1, 2]

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(
        low=0,
        high=num_expert,
        size=(batch_size, ),
        requires_grad=False,
    ).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=list(gpu_ids))
    linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=list(gpu_ids))
    for _ in range(5):
        moe_dp(inp, gate)


if __name__ == '__main__':
    # The process group uses the MPI backend and must be initialised before
    # any rank or world-size query below.
    torch.distributed.init_process_group(backend='mpi')
    # Kept as a module-level name for code that reads it as a global.
    world_size = torch.distributed.get_world_size()
    test()
    # Optional extras, disabled by default:
    # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
    # perf()