r"""
Balanced gate following GShard's load-balancing policy (Google, 2020).
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


class GShardGate(NaiveGate):
    r"""
    Top-2 gate with a GShard-style balance loss and a per-expert capacity.

    The gate is built on `NaiveGate` with `top_k=2`. On every forward pass it
    registers a balance loss through `set_loss` and then limits how many
    tokens each expert may receive.

    Args:
        d_model: forwarded unchanged to `NaiveGate`.
        num_expert: forwarded to `NaiveGate`; also the length of the
            per-expert counters built in `forward`.
        world_size: forwarded to `NaiveGate` and to the native capacity
            kernels.
        capacity: pair of capacity rates `(training_rate, eval_rate)`. The
            per-expert limit is `ceil(rate * x.shape[0])`.
    """

    def __init__(self, d_model, num_expert, world_size, capacity=(1.2, 2.4)):
        super().__init__(d_model, num_expert, world_size, top_k=2)
        self.capacity = capacity

    def forward(self, x):
        r"""
        Select experts for `x`, record the balance loss, and apply capacity.

        Returns:
            `(topk_idx, gate_score)`: the expert indices after the capacity
            pruning call, and the scores exactly as returned by
            `NaiveGate.forward`.
        """
        topk_idx, gate_score = super().forward(x)

        # S is the leading dimension of the scores; topk_idx is treated as a
        # flat vector holding `top_k` consecutive entries per row of scores.
        S = gate_score.shape[0]
        top_k = topk_idx.shape[0] // S
        top1_idx = topk_idx.view((-1, top_k))[:, 0]

        # c_e: fraction of rows whose first choice is each expert.
        c_e = torch.scatter_add(
            torch.zeros(self.num_expert, device=top1_idx.device),
            0,
            top1_idx,
            torch.ones_like(top1_idx, dtype=torch.float),
        ) / S
        # m_e: mean softmax score per column of gate_score.
        # NOTE(review): `c_e * m_e` needs gate_score.shape[1] to broadcast
        # against num_expert -- confirm NaiveGate.forward returns scores over
        # all experts rather than only the top-k columns.
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        # Separate capacity rates for training (index 0) and evaluation (1).
        cap_rate = self.capacity[0 if self.training else 1]
        # Created on the same device as the indices so that the native
        # kernels below never receive a CPU tensor next to device tensors.
        capacity = torch.full(
            (self.num_expert,),
            math.ceil(cap_rate * x.shape[0]),
            dtype=torch.int32,
            device=topk_idx.device,
        )

        # Count by the selected expert indices, not by the float scores.
        # NOTE(review): the third returned value is used as the count to be
        # limited; verify against count_by_gate's return order.
        _, _, gec = count_by_gate(topk_idx, self.num_expert, self.world_size)
        new_gec, = fmoe_native.limit_by_capacity(
            gec, capacity, self.num_expert, self.world_size)

        if self.world_size > 1:
            new_lec = fmoe_native.expert_exchange(
                new_gec, self.num_expert, self.world_size)
        else:
            new_lec = new_gec

        # The return value was not used by the original code.
        # NOTE(review): presumably this prunes `topk_idx` in place -- confirm
        # against the native implementation.
        fmoe_native.prune_gate_by_capacity(
            topk_idx, new_lec.to(torch.int32), self.num_expert,
            self.world_size)

        # The original returned the undefined name `topk_val` (NameError);
        # the scores produced by NaiveGate.forward are bound to `gate_score`.
        return topk_idx, gate_score