r"""
Different implementations of the Gate are located here.
`NaiveGate` is the reference implementation that any other gate should follow.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
8
from torch.distributions.normal import Normal


class ZeroGate(nn.Module):
    r"""
    Trivial gate that sends every token to expert 0.

    Only ``num_expert`` and ``top_k`` are used. The first and third positional
    arguments are accepted and ignored, mirroring the
    ``(d_model, num_expert, world_size, top_k)`` signature of the other gates
    in this module.
    """

    def __init__(self, _1, num_expert, _3, top_k=2):
        super().__init__()
        self.num_expert = num_expert
        self.top_k = top_k

    def forward(self, inp):
        r"""
        Route all ``inp.shape[0]`` tokens to expert 0.

        Returns a tuple of:
            - expert indices: int64 zeros of shape ``(tokens * top_k,)``;
            - scores: shape ``(tokens, 1, top_k)``, every entry ``1 / top_k``;
            - a ``(tokens, num_expert)`` matrix whose column 0 is 1 and all
              other entries are 0.
        """
        n_tokens = inp.shape[0]
        n_slots = n_tokens * self.top_k
        dev = inp.device

        score_all = torch.zeros(n_tokens, self.num_expert, device=dev)
        score_all[:, 0] = 1

        uniform = torch.ones(n_slots, device=dev) / self.top_k
        expert_idx = torch.zeros(n_slots, dtype=torch.int64, device=dev)

        return expert_idx, uniform.reshape(-1, 1, self.top_k), score_all


class NaiveGate(nn.Module):
    r"""
    Reference gate: a single linear projection followed by a top-k selection.

    Every token is mapped to one logit per expert (``num_expert * world_size``
    logits in total). The ``top_k`` largest logits are kept, and both the
    chosen expert indices and their scores (confidences) are handed back to
    the parent module. Load-balancing strategies are meant to be implemented
    inside gate modules like this one.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        r"""
        Pick the ``top_k`` highest-scoring experts for each token.

        All leading dimensions of ``inp`` are flattened into a single token
        dimension for the first two outputs.

        Returns a tuple of:
            - expert indices, flattened to shape ``(tokens * top_k,)``;
            - softmax over the selected logits, shape ``(tokens, 1, top_k)``;
            - the raw linear-layer logits, with the same leading dimensions
              as ``inp``.
        """
        k = self.top_k
        logits = self.gate(inp)
        # The relative order of the k winners does not matter, so skip sorting.
        top_val, top_idx = logits.topk(k, dim=-1, largest=True, sorted=False)
        scores = F.softmax(top_val.view(-1, k), dim=-1)
        return top_idx.view(-1), scores.unsqueeze(1), logits


class NoisyGate(nn.Module):
    r"""
    Noisy top-k gate with a load-balancing auxiliary loss.

    While training, input-dependent Gaussian noise (std-dev
    ``softplus(inp @ w_noise) + noise_epsilon``) is added to the gating logits
    ``inp @ w_gate`` before the top-k selection. In eval mode no noise is
    added. The structure appears to follow the "noisy top-k gating" of
    Shazeer et al. (2017).

    The third returned value is an auxiliary loss: the sum of the squared
    coefficients of variation of the per-expert *importance* (sum of gate
    values) and *load* (number of tokens routed to each expert; a smooth,
    differentiable estimate of it while training).
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        # Total number of experts across all workers.
        self.num_expert = num_expert * world_size
        self.w_gate = nn.Parameter(
            torch.zeros(d_model, num_expert * world_size), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(d_model, num_expert * world_size), requires_grad=True
        )
        self.top_k = top_k
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)

        # Lower bound added to the noise std-dev while training, keeping it
        # strictly positive (it is a divisor in `_prob_in_top_k`).
        self.noise_epsilon = 1e-2

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.

        The load is the number of examples for which the corresponding gate
        is > 0.

        Args:
            gates: a `Tensor` of shape [batch_size, n].
        Returns:
            an int64 `Tensor` of shape [n] (counts of positive gates).
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Helper for the noisy top-k gating load estimate.

        Computes, for each value, the probability that it lands in the top k
        under re-sampled Gaussian noise. This gives a differentiable signal
        for a loss that balances how often each expert is in the top k.

        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean
                values plus normally distributed noise with standard deviation
                noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n]. It must be strictly
                positive, because the code divides by it; `None` is not
                supported.
            noisy_top_values: a `Tensor` of shape [batch, m], the top-m noisy
                values per row with m >= k + 1.
        Returns:
            a `Tensor` of shape [batch, n].
        """

        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        # Flat index of the (k+1)-th largest noisy value in each row: the bar
        # a value currently inside the top k must stay above.
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.top_k
        )
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
        )
        # is each value currently in the top k.
        is_in = torch.gt(noisy_values, threshold_if_in)
        # The k-th largest value: the bar an outside value must beat.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
        )
        normal = Normal(
            torch.tensor([0.0], device=clean_values.device),
            torch.tensor([1.0], device=clean_values.device),
        )
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.

        Useful as a loss to encourage a positive distribution to be more
        uniform. An epsilon is added to the denominator for numerical
        stability.

        Args:
            x: a 1-D `Tensor`.
        Returns:
            a 0-dim float `Tensor` on ``x.device``. It is 0 when ``x`` has
            zero or one element (e.g. a single expert), where the variance is
            undefined.
        """
        eps = 1e-10
        if x.shape[0] <= 1:
            # Zero-dim tensor on x's device, so it can be combined with
            # the other loss terms without a device mismatch.
            return torch.zeros((), device=x.device)
        x = x.float()
        return x.var() / (x.mean() ** 2 + eps)

    def forward(self, inp):
        r"""
        Select ``top_k`` experts per token and compute the balance loss.

        ``inp`` is expected to be 2-D (tokens x d_model): top-k, softmax and
        scatter all operate on dim 1.

        Returns a tuple of:
            - expert indices, flattened to shape ``(tokens * top_k,)``;
            - softmax over the selected logits, shape ``(tokens, 1, top_k)``;
            - the scalar load-balancing loss.
        """
        clean_logits = inp @ self.w_gate
        raw_noise_stddev = inp @ self.w_noise
        # Multiplying by the bool `self.training` zeroes the noise in eval mode.
        noise_stddev = (
            self.softplus(raw_noise_stddev) + self.noise_epsilon
        ) * self.training
        noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
        logits = noisy_logits

        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(
            min(self.top_k + 1, self.num_expert), dim=1
        )
        top_k_logits = top_logits[:, : self.top_k]
        top_k_indices = top_indices[:, : self.top_k]
        top_k_gates = self.softmax(top_k_logits)

        # Dense [tokens, num_expert] gate matrix: zero except the top-k slots.
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        # The smooth load estimate divides by noise_stddev, which is exactly
        # zero outside training (and tied logits would then yield 0/0 = NaN),
        # so fall back to the true counts whenever no noise is applied.
        if self.training and self.top_k < self.num_expert:
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_logits
                )
            ).sum(0)
        else:
            load = self._gates_to_load(gates)

        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)

        return (
            top_k_indices.contiguous().view(-1),
            top_k_gates.contiguous().unsqueeze(1),
            loss,
        )