gshard_gate.py 1.66 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
r"""
Balanced gate with GShard's policy (Google, 2020)
"""
Rick Ho's avatar
Rick Ho committed
4
import math
Rick Ho's avatar
Rick Ho committed
5
6
7
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
Rich Ho's avatar
Rich Ho committed
8
from .utils import limit_by_capacity
Rick Ho's avatar
Rick Ho committed
9
10
11


class GShardGate(NaiveGate):
    r"""Top-2 gate following GShard's balancing policy.

    On top of the routing produced by ``NaiveGate`` this gate:

    * records an auxiliary load-balancing loss through ``self.set_loss``;
    * limits how many tokens each expert receives via ``limit_by_capacity``;
    * optionally applies random routing, which may drop a token's second
      expert choice (its index is overwritten with ``-1``).

    Args:
        d_model: forwarded to ``NaiveGate``.
        num_expert: forwarded to ``NaiveGate``; also scales the balancing
            loss and is passed to ``limit_by_capacity``.
        world_size: forwarded to ``NaiveGate`` and ``limit_by_capacity``.
        topk: must be 2; any other value fails the assertion.
        capacity: pair of rates; ``capacity[0]`` is used while
            ``self.training`` is true and ``capacity[1]`` otherwise.
        random_routing: if true, ``forward`` applies random routing to each
            token's second choice.
    """

    def __init__(self, d_model, num_expert, world_size,
            topk=2, capacity=(1.2, 2.4), random_routing=True):
        assert topk == 2, 'topk should be 2 in gshard'
        super().__init__(d_model, num_expert, world_size, top_k=2)
        self.capacity = capacity
        # Store the caller's choice.  This used to be hard-coded to True,
        # which made ``random_routing=False`` a silent no-op.
        self.random_routing = random_routing

    def forward(self, x):
        """Route ``x`` to experts and return ``(topk_idx, topk_val)``.

        Side effect: stores the load-balancing loss via ``self.set_loss``.

        Args:
            x: gate input, forwarded to ``NaiveGate.forward``.  Its first
                dimension is used to size the per-expert capacity
                (presumably the token count -- TODO confirm).

        Returns:
            ``topk_idx`` as returned by ``limit_by_capacity``, with second
            choices dropped by random routing set to ``-1``, and the
            unmodified ``topk_val`` from ``NaiveGate``.
        """
        topk_idx, topk_val, gate_score = super().forward(
                x, return_all_scores=True)

        # One row of ``gate_score`` per token: it sizes the per-token random
        # draw below, which is compared element-wise with topk_val[:, 1].
        S = gate_score.shape[0]
        # First choice of every token.  Reshaping to (S, -1) yields one row
        # per token whether ``topk_idx`` is laid out as (S, k) -- as the
        # ``[:, 1]`` indexing below requires -- or flattened to (S * k,).
        # The previous ``topk_idx.shape[0] // S`` evaluated to 1 for the
        # (S, k) layout, so every choice was counted as a "first" choice.
        top1_idx = topk_idx.reshape(S, -1)[:, 0]

        # c_e: fraction of tokens whose first choice is each expert.
        # NOTE(review): ``self.tot_expert`` comes from NaiveGate; presumably
        # num_expert * world_size -- confirm.
        c_e = torch.scatter_add(
                torch.zeros(self.tot_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
                ) / S
        # m_e: mean softmax probability assigned to each expert.
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        # Looser capacity at evaluation time than during training (by the
        # default rates); the limit is proportional to x's first dimension.
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * x.shape[0])
        _new_lec, _new_gec, topk_idx = limit_by_capacity(
                topk_idx, self.num_expert, self.world_size, capacity)

        if self.random_routing:
            # Keep each token's second choice only with probability
            # min(1, 2 * topk_val[:, 1]); dropped choices become index -1.
            # NOTE(review): GShard applies this rule to normalized gate
            # weights -- confirm that NaiveGate's topk_val is normalized.
            rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
            mask = (2 * topk_val[:, 1] < rand_routing_prob)
            topk_idx[:, 1].masked_fill_(mask, -1)

        return topk_idx, topk_val