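"""Tests that FastMoE's forward pass handles experts that receive zero tokens.

Both tests route every token to expert 0, leaving the remaining experts
with empty input batches, and check that the forward pass still runs.

Launch with one process per GPU, e.g. (assuming a torchrun-compatible
PyTorch install; older versions can use python -m torch.distributed.launch):

    torchrun --nproc_per_node=<num_gpus> test_zero.py
"""
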
import torch
from fmoe.layers import _fmoe_general_global_forward
from fmoe import FMoETransformerMLP


class ConstantGate(torch.nn.Module):
    """A degenerate gate that sends every token to expert 0, so all other
    experts receive zero tokens."""

    def __init__(self, d_model, num_expert, world_size, top_k=1):
        super().__init__()
        self.top_k = top_k

    def forward(self, inp):
        # Expert index 0 for every (token, top-k) slot.
        idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
                device=inp.device)
        # A constant gate score of 0.5 per selection.
        score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
        return idx, score, None


def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
    inp = torch.rand(batch_size, d_hidden).cuda()
    # Route every token to expert 0, so all other experts get zero tokens.
    gate = torch.zeros(batch_size, dtype=torch.int64).cuda()
    # An identity expert function suffices: the test only checks that the
    # scatter/gather forward runs when some experts receive no input.
    _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert,
            world_size)


def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
    inp = torch.rand(batch_size, d_hidden).cuda()
    model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
            gate=ConstantGate).cuda()
    # With ConstantGate, only expert 0 receives tokens; the forward pass
    # should still complete for the experts with empty batches.
    model(inp)


if __name__ == '__main__':
    # Expects RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT in the environment,
    # as set by torchrun or torch.distributed.launch.
    torch.distributed.init_process_group(backend="nccl")
    # One process per GPU: bind each rank to its own device (single node).
    torch.cuda.set_device(torch.distributed.get_rank())
    # test_zero_fwd(world_size=torch.distributed.get_world_size())
    test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
            world_size=torch.distributed.get_world_size())
    print('done')