r"""
Balanced gate with GShard's policy (Google, 2020).

Implements the top-2 routing described in "GShard: Scaling Giant Models with
Conditional Computation and Automatic Sharding": an auxiliary load-balancing
loss, per-expert capacity limits and optional random routing of the second
expert choice.
"""
import math

import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
# NOTE(review): limit_by_capacity is not referenced by GShardGate in this
# file; confirm no other code relies on it being importable from here before
# removing it.
from .utils import limit_by_capacity
# Native FastMoE extension; provides prune_gate_by_capacity used below.
import fmoe_cuda as fmoe_native


class GShardGate(NaiveGate):
    r"""
    Top-2 gate following GShard's routing policy.

    On top of the top-2 selection done by ``NaiveGate`` this gate:

    * records an auxiliary load-balancing loss through ``self.set_loss``;
    * limits how many assignments each expert may receive, using a capacity
      derived from ``self.capacity`` and the number of input rows;
    * optionally applies GShard-style random routing to the second choice.
    """

    def __init__(self, d_model, num_expert, world_size,
            topk=2, capacity=(1.2, 2.4), random_routing=True):
        r"""
        Args:
            d_model: forwarded to ``NaiveGate``.
            num_expert: forwarded to ``NaiveGate``; also used to size the
                per-expert capacity tensor and to scale the auxiliary loss.
            world_size: forwarded to ``NaiveGate``; also used to size the
                per-expert capacity tensor and to divide the capacity.
            topk: must be 2 (GShard always routes to two experts). The base
                gate is always constructed with ``top_k=2``.
            capacity: ``(train_rate, eval_rate)`` capacity factors. The first
                is used while ``self.training`` is true, the second otherwise.
            random_routing: if true, a row's second expert choice is kept only
                with probability ``min(1, 2 * topk_val[:, 1])``.

        Raises:
            AssertionError: if ``topk != 2``.
        """
        assert topk == 2, 'topk should be 2 in gshard'
        super().__init__(d_model, num_expert, world_size, top_k=2)
        self.capacity = capacity
        self.random_routing = random_routing

    def forward(self, x):
        r"""
        Select two experts per row of ``x`` and apply GShard's policies.

        Side effect: stores the auxiliary load-balancing loss via
        ``self.set_loss``.

        Returns:
            ``(topk_idx, topk_val)``. ``topk_idx`` holds the chosen expert
            indices. ``-1`` marks a second choice dropped by random routing
            and, presumably, assignments dropped by capacity pruning (confirm
            against ``fmoe_native.prune_gate_by_capacity``). ``topk_val`` is
            returned exactly as produced by ``NaiveGate``; dropped entries
            keep their original values.
        """
        topk_idx, topk_val, gate_score = super().forward(
                x, return_all_scores=True)

        # ---- auxiliary load-balancing loss --------------------------------
        S = gate_score.shape[0]
        top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
        # c_e[e]: fraction of the S rows whose first choice is expert e.
        c_e = torch.scatter_add(
                torch.zeros(self.tot_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
                ) / S
        # m_e[e]: mean (over rows) softmax probability assigned to expert e.
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        # NOTE(review): c_e has self.tot_expert entries, but the scale uses
        # self.num_expert; if tot_expert != num_expert (world_size > 1) this
        # differs from scaling by the total expert count -- confirm intent.
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        # ---- capacity pruning ---------------------------------------------
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * x.shape[0]) // self.world_size
        # One capacity entry per expert across all workers. A fresh tensor is
        # built on each call. NOTE(review): the native kernel may consume the
        # counts in place -- confirm before caching or reusing this tensor.
        capacity = torch.full((self.num_expert * self.world_size,), capacity,
                dtype=torch.int32, device=topk_idx.device)
        topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity,
                self.num_expert, self.world_size)

        # ---- random routing of the second choice --------------------------
        # The second expert is dropped when 2 * g2 < U, U ~ Uniform[0, 1), so
        # it is kept with probability min(1, 2 * g2). Applied in both training
        # and eval mode.
        # NOTE(review): this runs after capacity pruning. Moving it earlier
        # would feed -1 indices to prune_gate_by_capacity; confirm the kernel
        # tolerates that before reordering.
        if self.random_routing:
            rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
            mask = (2 * topk_val[:, 1] < rand_routing_prob)
            topk_idx[:, 1].masked_fill_(mask, -1)

        return topk_idx, topk_val