swipe_gate.py 1.43 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
r"""
Balanced gate using SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate

from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


class SwipeGate(NaiveGate):
    requires_moe_group = True

    def __init__(self, d_model, num_expert, world_size, topk=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def swipe_once(self, idx, capacity):
        with torch.no_grad():
            idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
                    self.num_expert, self.world_size)
            idx_new = idx_new.to(idx.device)
        return idx_new, capacity


    def forward(self, inp):
        score = self.gate(inp)
        _, orig_idx = torch.topk(gate_score, k=self.top_k, dim=-1)

        if not self.training:
            topk_val = F.softmax(topk_val, dim=-1)
            return topk_idx, topk_val

        capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
                dtype=torch.long)
        
        topk_idxs = []
        for k in range(self.top_k):
            idx, capacity = self.swipe_once(orig_idx[:, k], capacity)
            topk_idxs.append(idx)
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val