test_gates.py 746 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os
import torch
import torch.distributed as dist
from fmoe.gates import GShardGate


def test_gshard_gate(d_model, batch_size, n_expert):
    """Run a GShardGate forward pass on random input and print its outputs.

    A GShardGate is built for ``n_expert`` experts per worker across the
    current ``torch.distributed`` world and moved to the GPU. It is applied
    to a random ``(batch_size, d_model)`` tensor. The two values returned
    by the gate are printed, tagged with this process's rank.

    Requires an initialized default process group and a CUDA device.
    """
    # Build the gate before sampling the input so RNG consumption order
    # (weight init first, then torch.rand) stays the same.
    gshard = GShardGate(d_model, n_expert, dist.get_world_size()).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    idx, val = gshard(inp)

    rank = dist.get_rank()
    print(f'rank {rank} idx {idx}')
    print(f'rank {rank} val {val}')


def setup_mpi_env():
    """Translate Open MPI launcher variables into torch.distributed ones.

    Sets ``RANK`` and ``WORLD_SIZE`` (read by ``init_process_group``'s
    default ``env://`` method) from ``OMPI_COMM_WORLD_RANK`` and
    ``OMPI_COMM_WORLD_SIZE``. When the script is not run under mpirun, it
    falls back to a single process of rank 0.

    Each process is also pinned to exactly one GPU through
    ``CUDA_VISIBLE_DEVICES``. The node-local rank is used for this, not the
    global rank: on multi-node runs the global rank can exceed the number
    of GPUs on a node, which would leave that process with no visible
    device. If the local rank is unavailable, the global rank is used.
    """
    rank = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["RANK"] = rank
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get(
        "OMPI_COMM_WORLD_LOCAL_RANK", rank)


if __name__ == '__main__':
    # Must run before any CUDA initialization so CUDA_VISIBLE_DEVICES
    # takes effect for this process.
    setup_mpi_env()
    # NOTE(review): env:// init also needs MASTER_ADDR / MASTER_PORT to be
    # provided by the launcher environment.
    dist.init_process_group(backend="nccl")
    test_gshard_gate(4096, 1024, 4)