utils.py 783 Bytes
Newer Older
Rich Ho's avatar
Rich Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
r"""
Utilities that may be used in the gates
"""
import torch
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    r"""
    Limit the number of tokens routed to each expert by ``capacity``.

    Gate assignments are counted with ``count_by_gate``. The native
    ``limit_by_capacity`` kernel then caps those counts. The capped counts are
    handed to ``prune_gate_by_capacity`` together with ``topk_idx``.

    Args:
        topk_idx: tensor of expert indices chosen by the gate. Its shape is
            not visible here; presumably ``(tokens, top_k)``. TODO confirm.
        num_expert: number of experts per worker.
        world_size: number of workers. When it is greater than 1, counts are
            exchanged between workers with ``expert_exchange``.
        capacity: per-expert capacity, broadcast to every expert.

    Returns:
        ``(new_lec, new_gec)``: the capped counts as returned by the native
        kernels. When ``world_size == 1`` they are the same tensor.

    Note:
        The return value of ``prune_gate_by_capacity`` is discarded, so it
        presumably rewrites ``topk_idx`` in place. Verify against the native
        implementation.
    """
    # One capacity entry per expert, on the same device as the gate output.
    # NOTE(review): if ``capacity`` is a Python float, PyTorch type promotion
    # makes this tensor floating-point rather than int32. Confirm callers
    # pass an int.
    capacity = torch.ones(num_expert, dtype=torch.int32,
            device=topk_idx.device) * capacity

    # Only the third count tensor (named ``gec``) is needed here. The position
    # output is not requested (require_pos=False).
    _pos, _lec, gec = count_by_gate(topk_idx, num_expert, world_size,
            require_pos=False)

    # The native call returns a one-element sequence, so unpack it.
    new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
            num_expert, world_size)
    if world_size > 1:
        # Exchange the capped counts between workers to obtain the
        # corresponding per-worker counts.
        new_lec, = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
    else:
        # Single worker: there is nothing to exchange.
        new_lec = new_gec

    # The native pruning kernel is given int32 counts.
    fmoe_native.prune_gate_by_capacity(topk_idx,
            new_lec.to(torch.int32), num_expert, world_size)

    return new_lec, new_gec