# swipe_gate.py (committed by Rick Ho)
r"""
Balanced gate using SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate

from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


class SwipeGate(NaiveGate):
    r"""
    Top-k gate that, while training, passes each of its top-k expert choices
    through the native ``swipe_once`` routine before scoring them.

    In evaluation mode the plain top-k selection over the gate scores is
    returned and the native routine is not called.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        # Every argument is forwarded unchanged to NaiveGate.
        # NOTE(review): `gate`, `top_k`, `num_expert` and `world_size` used
        # below are presumably set up by NaiveGate.__init__ -- confirm there.
        super().__init__(d_model, num_expert, world_size, top_k)

    def swipe_once(self, idx, capacity, bias):
        r"""
        Call the native ``fmoe_cuda.swipe_once`` once, outside of autograd.

        Args:
            idx: tensor of expert indices handed to the native routine.
            capacity: capacity tensor handed to the native routine.
            bias: integer forwarded as the last argument of the native routine.
                NOTE(review): its exact meaning is defined in ``fmoe_cuda``,
                not in this file -- confirm there.

        Returns:
            A pair ``(idx_new, capacity)``: the index tensor produced by the
            native routine, moved to ``idx.device``, and the capacity value
            the native routine returned.
        """
        with torch.no_grad():
            idx_new, capacity = fmoe_native.swipe_once(
                idx, capacity, self.num_expert, self.world_size, bias
            )
            # Make sure the result lives on the same device as the input.
            idx_new = idx_new.to(idx.device)
        return idx_new, capacity

    def forward(self, inp):
        r"""
        Compute the selected expert indices and their normalized scores.

        Args:
            inp: input tensor fed to ``self.gate``.
                NOTE(review): the indexing below (``inp.shape[0]``,
                ``orig_idx[:, k]``, ``score[idx_x, idx]``) assumes a 2-D
                input whose first dimension enumerates tokens -- confirm
                against callers.

        Returns:
            A pair ``(topk_idx, topk_val)``. ``topk_val`` is a softmax over
            the last dimension of the scores gathered at ``topk_idx``.
        """
        score = self.gate(inp)
        orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)

        # Evaluation: use the raw top-k choice, no call to swipe_once.
        if not self.training:
            topk_val = F.softmax(orig_score, dim=-1)
            return orig_idx, topk_val

        # Scalar long tensor, created without an explicit device argument.
        # It is threaded through the successive swipe_once calls below, each
        # of which returns the value to use for the next call.
        capacity = torch.scalar_tensor(
            inp.shape[0] * self.top_k, dtype=torch.long
        )

        topk_idxs = []
        topk_vals = []
        # Row selector used to gather one score per row of `score`.
        idx_x = torch.arange(inp.shape[0], device=inp.device)
        for k in range(self.top_k):
            # The calls must stay sequential: each consumes the capacity
            # returned by the previous one.
            idx, capacity = self.swipe_once(
                orig_idx[:, k], capacity, k % self.num_expert
            )
            # Score each row at the index returned by swipe_once.
            topk_vals.append(score[idx_x, idx])
            topk_idxs.append(idx)

        # Stacking puts the k axis first; transpose so it comes second,
        # matching the layout of the evaluation branch. The transpose is
        # kept (rather than stacking on dim=1) to preserve the original
        # memory layout of the returned index tensor.
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        topk_val = torch.stack(topk_vals).transpose(0, 1)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val