test_numerical.py 2.58 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from moe import FMoE as MOELayer 
from moe import BruteForceMoE as MOELayer_raw
import torch
from torch import nn
import sys
import os


# Rank of this process and total number of processes taking part in the test.
# Both stay None until the `__main__` block of this file assigns them;
# test_moe() reads them as module-level globals.
rank = None
world_size = None


def test_moe():
    """Compare FMoE against the brute-force reference layer on one batch.

    A shared ``nn.Linear`` feeds both ``MOELayer`` and ``MOELayer_raw``, which
    are given the same expert weights.  One forward/backward pass is run
    through each, and the output plus the gradients of the MoE weight, the
    linear weight and the linear bias are compared by summed absolute error.

    Reads the module globals ``rank`` and ``world_size``; they must be set
    before calling.  Requires CUDA, and an initialised ``torch.distributed``
    process group when ``world_size > 1``.

    The per-quantity error is printed.  On the first quantity whose error
    exceeds 1e-3 both tensors are dumped to stderr and the function returns;
    no exception is raised, and the return value is always ``None``.
    """
    def test_module(moe, linear, inp, gate):
        """Run one forward/backward pass and snapshot the results.

        Returns ``(output, moe weight grad, linear weight grad,
        linear bias grad)``.
        """
        linear.zero_grad()
        moe.zero_grad()
        x = linear(inp)
        output = moe(x, gate)
        y = output.mean()
        y.backward()
        # Return copies of the gradients.  `linear` is reused for the second
        # run; when zero_grad() zeroes the .grad tensors in place, the next
        # backward() accumulates into the very same tensor objects, so live
        # references from the first run would alias the second run's
        # gradients and the linear comparisons would trivially report 0.
        return (output,
                moe.weight.grad.clone(),
                linear.weight.grad.clone(),
                linear.bias.grad.clone())

    # Offset the seed by rank so every process draws different data.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        # Gather every rank's FMoE weight and concatenate along dim 0 so the
        # reference layer holds the weights of all ranks.
        weight_array = [torch.empty_like(moe.weight.data)
                        for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data)
        moe_raw.weight.data = torch.cat(weight_array, dim=0)

    inp = torch.rand(batch_size, in_feat).cuda()
    # One expert index per sample, drawn from [0, num_expert * world_size).
    gate = torch.randint(low=0,
                         high=num_expert * world_size,
                         size=(batch_size,),
                         requires_grad=False).int().cuda()
    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    if world_size > 1:
        ou, wg, lwg, lbg = raw_out
        # Sum the reference weight gradient over all ranks, then keep the
        # rows [rank * num_expert, (rank + 1) * num_expert) -- presumably the
        # experts held by this rank's FMoE (matches the all_gather order).
        torch.distributed.all_reduce(wg)
        wg = wg[rank * num_expert:(rank + 1) * num_expert]
        raw_out = ou, wg, lwg, lbg
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('Rank {} {} abs err {}'.format(rank, name, err))
        if err > 1e-3:
            # Dump both tensors for inspection and stop at the first mismatch.
            sys.stderr.write('=========== moe out ==============\n')
            sys.stderr.write('{}\n'.format(mo))
            sys.stderr.write('=========== raw out ==============\n')
            sys.stderr.write('{}\n'.format(ro))
            return


if __name__ == '__main__':
    # Publish the process rank / count under the RANK and WORLD_SIZE names.
    # OpenMPI's variables take precedence when present; otherwise keep any
    # value another launcher already exported instead of overwriting it, and
    # only then fall back to the single-process defaults.
    os.environ['RANK'] = os.environ.get(
        'OMPI_COMM_WORLD_RANK', os.environ.get('RANK', '0'))
    os.environ['WORLD_SIZE'] = os.environ.get(
        'OMPI_COMM_WORLD_SIZE', os.environ.get('WORLD_SIZE', '1'))
    if int(os.environ['WORLD_SIZE']) > 1:
        # Multi-process run: set up the NCCL process group and take rank and
        # world size from it.
        torch.distributed.init_process_group(backend='nccl')
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        # Single-process run: no process group is created.
        rank = 0
        world_size = 1
    test_moe()