"src/rewrite_gemm.cpp" did not exist on "9164116ad361286db786bad11c346bb70f8be3c3"
naive_gate.py 1.37 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
r"""
Naive gate
"""
from .base_gate import BaseGate

import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveGate(BaseGate):
    r"""
    A naive gate implementation that defines the standard behavior of the gate
    which determines which experts the tokens are going to.
    Both the indices and the score, or confidence, are output to the parent
    module.
    The load-balance strategies are also designed to be implemented within the
    `Gate` module.

    Args:
        d_model: input feature size of the scoring linear layer, i.e. the size
            of the last dimension of the tensor passed to `forward`.
        num_expert: forwarded unchanged to `BaseGate.__init__`.
        world_size: forwarded unchanged to `BaseGate.__init__`.
        top_k: number of experts selected for every input vector.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(num_expert, world_size)
        # NOTE(review): `self.tot_expert` is not assigned in this class; it is
        # presumably set by `BaseGate.__init__` -- confirm against base_gate.
        self.gate = nn.Linear(d_model, self.tot_expert)
        self.top_k = top_k

    def forward(self, inp, return_all_scores=False):
        r"""
        The naive implementation simply calculates the top-k of a linear layer's
        output.

        Args:
            inp: tensor whose last dimension is `d_model`.
            return_all_scores: when true, the raw output of the linear layer
                is returned as a third element.

        Returns:
            `(gate_top_k_idx, gate_score)` or, when `return_all_scores` is
            set, `(gate_top_k_idx, gate_score, gate)`.
            `gate_top_k_idx` keeps the leading dimensions of `inp` with a last
            dimension of `top_k`; `gate_score` is flattened to
            `[-1, top_k]` and holds the softmax over the selected logits.
        """
        gate = self.gate(inp)
        # sorted=False: the k selected experts are not ordered by score.
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)

        # [-1 x top_k]: normalise only over the selected experts.
        gate_score = F.softmax(gate_top_k_val, dim=-1)

        # Dummy loss. It is allocated on the input's device instead of the
        # current CUDA device, so the gate also works on CPU-only hosts and
        # the loss stays on the same device as the activations.
        self.set_loss(torch.zeros(1, requires_grad=True, device=inp.device))

        if return_all_scores:
            return gate_top_k_idx, gate_score, gate
        return gate_top_k_idx, gate_score