gates.py 1.84 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Different implementations of the Gate are located here.
The `NaiveGate` is the reference for implementing any other gate.
"""
Rick Ho's avatar
Rick Ho committed
5
6
7
8
9
import torch
import torch.nn as nn
import torch.nn.functional as F


Rick Ho's avatar
Rick Ho committed
10
class ZeroGate(nn.Module):
    r"""
    Guide all input samples to gate 0.

    A trivial gate: every token is assigned expert index 0 in each of its
    `top_k` slots, and every slot receives the same score `1 / top_k`.
    """

    def __init__(self, _1, _2, _3, top_k=2):
        r"""
        Build the gate.

        The three leading positional arguments are accepted and ignored; they
        sit in the same positions as `d_model`, `num_expert` and `world_size`
        of `NaiveGate`, so both gates can be constructed with the same call.

        Args:
            _1, _2, _3: unused placeholders.
            top_k: number of expert slots produced for each token.
        """
        super().__init__()
        self.top_k = top_k

    def forward(self, inp):
        r"""
        Send every token to expert 0.

        Args:
            inp: input tensor whose last dimension is the feature dimension;
                all leading dimensions are treated as tokens.

        Returns:
            A tuple `(idx, score)` where `idx` is a 1-D int64 tensor of zeros
            with `num_tokens * top_k` entries and `score` is a tensor of shape
            `(num_tokens, 1, top_k)` filled with `1 / top_k`. Both live on
            `inp.device`.
        """
        # Count tokens over every leading dimension (not just the first one)
        # so the output sizes agree with `NaiveGate`, which flattens all
        # leading dimensions. For a 2-D input this equals `inp.shape[0]`.
        num_tokens = inp.shape[:-1].numel()

        idx = torch.zeros(
            num_tokens * self.top_k, dtype=torch.int64, device=inp.device
        )
        # Allocate directly in the final layout instead of reshaping later.
        score = (
            torch.ones(num_tokens, 1, self.top_k, device=inp.device)
            / self.top_k
        )
        return idx, score
Rick Ho's avatar
Rick Ho committed
29
30


Rick Ho's avatar
Rick Ho committed
31
class NaiveGate(nn.Module):
    r"""
    A naive gate implementation that defines the standard behavior of the gate
    which determines which experts the tokens are going to.
    Both the indices and the score, or confidence, are output to the parent
    module.
    The load-balance strategies are also designed to be implemented within the
    `Gate` module.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        r"""
        Build the gate.

        Args:
            d_model: size of the last (feature) dimension of the input.
            num_expert: multiplied by `world_size` to get the number of gate
                outputs (presumably the expert count per worker).
            world_size: multiplied by `num_expert` to get the number of gate
                outputs (presumably the number of workers).
            top_k: number of experts selected for each token.
        """
        super().__init__()
        # One logit per candidate expert: `num_expert * world_size` outputs.
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        r"""
        The naive implementation simply calculates the top-k of a linear layer's
        output.

        Args:
            inp: input tensor whose last dimension has size `d_model`.

        Returns:
            A tuple `(gate_top_k_idx, gate_score)`: the selected expert
            indices flattened to 1-D, and the softmax over the selected
            logits with shape `(num_tokens, 1, top_k)`.
        """
        gate = self.gate(inp)
        # `sorted=False`: the order of the k selected experts is unspecified.
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)

        # (BxL) x 1 x top_k
        # The softmax is taken over the k selected logits only, so the
        # scores of each token sum to 1.
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

        return gate_top_k_idx, gate_score