test_dp.py 792 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from moe import FMoE as MOELayer 
from moe import BruteForceMoE as MOELayer_raw
import torch
from torch import nn
import sys
import os


# Number of devices passed to DataParallel as device ids 0..n_devices-1.
# Read from the N_GPUS environment variable; falls back to 2 when unset.
n_devices = int(os.getenv('N_GPUS', '2'))


def test_dp():
    """Smoke-test the MoE layer wrapped in ``torch.nn.DataParallel``.

    Calls ``torch.manual_seed`` and ``torch.cuda.manual_seed`` with 42,
    builds a random float input of shape ``(6, 2)`` and a random integer
    gate tensor of shape ``(6,)`` with values in ``[0, 4)``, then runs five
    forward passes through the data-parallel module on devices
    ``0 .. n_devices - 1``.

    The outputs are not inspected; the test succeeds when no call raises.
    """
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    n_rows, n_experts = 6, 4
    d_in, d_out = 2, 3

    # Both tensors are sampled on the CPU first and then copied to the GPU,
    # so the values come from the CPU generator seeded above.
    inp = torch.rand(n_rows, d_in).cuda()
    # Presumably one expert index per input row -- confirm against FMoE.
    gate = torch.randint(
        low=0, high=n_experts, size=(n_rows,), requires_grad=False
    ).cuda()

    print("data parallel of our MoE model")
    layer = MOELayer(n_experts, d_in, d_out).cuda()
    device_ids = list(range(n_devices))
    moe_dp = torch.nn.DataParallel(layer, device_ids=device_ids)

    # Repeat the forward pass a few times; only the absence of errors matters.
    for _ in range(5):
        moe_dp(inp, gate)
    print('Successful')


# Run the data-parallel smoke test when this file is executed as a script.
if "__main__" == __name__:
    test_dp()