r"""
Load-balanced top-1 gate following the Switch Transformer routing policy
(Fedus et al., Google, 2021).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity


class SwitchGate(NaiveGate):
    r"""
    Top-1 gate following the Switch Transformer routing policy.

    Every input row is routed to the single expert with the highest softmax
    probability. The per-expert load is then capped by ``limit_by_capacity``,
    and a load-balancing auxiliary loss is recorded through ``set_loss``.

    Args:
        d_model, num_expert, world_size: forwarded unchanged to
            ``NaiveGate`` (which is built with ``top_k=1``).
        switch_eps: half-width of the uniform noise applied to the gate
            logits in training mode.
        capacity: pair ``(train_rate, eval_rate)``. The rate selected by
            ``self.training`` is multiplied by ``inp.shape[0]`` and rounded
            up to obtain the capacity handed to ``limit_by_capacity``.
    """

    def __init__(self, d_model, num_expert, world_size,
            switch_eps=.1, capacity=(1.2, 2.4)):
        super().__init__(d_model, num_expert, world_size, top_k=1)
        self.switch_eps = switch_eps
        self.capacity = capacity

    def forward(self, inp):
        r"""
        Pick the top-1 expert for every row of ``inp``.

        The gate logits are perturbed with noise (training only), passed
        through an fp32 softmax, and the argmax expert is selected. Over-
        capacity assignments are handled by ``limit_by_capacity``, and the
        balance loss is stored with ``set_loss``.

        Returns:
            ``(top1_idx, top1_score)``: the chosen expert index per row
            (entries rejected by capacity are presumably negative, since only
            ``top1_idx > -1`` is counted below) and the matching softmax
            probability cast to ``inp.dtype``.
        """
        score = self.gate(inp)

        if self.training:
            # Noise drawn uniformly from [1 - eps, 1 + eps] and *added* to
            # the logits. Softmax is shift-invariant, so the constant 1
            # cancels and the effective perturbation is uniform in
            # [-eps, eps].
            noise = torch.rand_like(score)
            noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
            score += noise

        # fp32 softmax for numerical stability
        score = F.softmax(score.float(), dim=-1)

        top1_score, top1_idx = torch.topk(
            score, k=1, dim=-1, largest=True
        )  # [.. x top_k]
        # Selection runs in fp32; hand the weights back in the caller's dtype.
        top1_score = top1_score.to(dtype=inp.dtype)

        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * inp.shape[0])
        # NOTE(review): the return value is discarded, so this relies on
        # ``limit_by_capacity`` pruning ``top1_idx`` in place -- confirm
        # against .utils. Also confirm the intended unit of ``capacity``:
        # it scales with all rows of ``inp``, not rows / num_expert.
        limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)

        valid_idx = top1_idx[top1_idx > -1]
        # Guard the normalizer: with zero surviving tokens (e.g. an empty
        # batch) the divisions below would be 0/0 = NaN, and a NaN loss
        # poisons every gradient flowing back through ``score``.
        num_valid = max(valid_idx.numel(), 1)
        fraction_expert = torch.scatter_add(
                torch.zeros(self.tot_expert, device=valid_idx.device),
                0,
                valid_idx,
                torch.ones_like(valid_idx, dtype=torch.float),
            ) / num_valid
        # NOTE(review): the loss sums ``score`` over dim 0, which assumes a
        # 2-D (rows, experts) score. It also normalizes by the number of
        # tokens kept after capacity pruning rather than by all tokens.
        prob_expert = score.sum(dim=0) / num_valid
        loss = (fraction_expert * prob_expert).sum() * self.tot_expert
        self.set_loss(loss)
        return top1_idx, top1_score