"docs/vscode:/vscode.git/clone" did not exist on "4b04998d3854830d482b7f5b2eda4ebb49e3dd19"
base_gate.py 685 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
r"""
Base gate with standard interface
"""
import torch.nn as nn


class BaseGate(nn.Module):
    r"""
    Base class giving gate modules a common interface.

    It records the expert counts passed at construction and keeps a single
    ``loss`` slot that can be stored, read back and cleared. ``forward`` is
    deliberately left unimplemented and always raises.
    """

    def __init__(self, num_expert, world_size):
        r"""
        Store the expert layout and start with no loss recorded.

        Args:
            num_expert: expert count; presumably the number of experts on
                each worker -- confirm against callers.
            world_size: multiplier applied to ``num_expert``; presumably
                the number of workers -- confirm against callers.
        """
        super().__init__()
        self.world_size = world_size
        self.num_expert = num_expert
        # Product of the two counts above.
        self.tot_expert = world_size * num_expert
        # None means "no loss currently stored" (see has_loss).
        self.loss = None

    def forward(self, x):
        r"""
        Always raise ``NotImplementedError``; the argument ``x`` is unused.
        """
        raise NotImplementedError('Base gate cannot be directly used for fwd')

    def set_loss(self, loss):
        r"""
        Store ``loss`` on the gate, replacing any previously stored value.
        """
        self.loss = loss

    def get_loss(self, clear=True):
        r"""
        Return the stored loss (``None`` if nothing is stored).

        Args:
            clear: when true (the default), reset the stored loss to
                ``None`` after reading it, so the next call returns
                ``None`` unless ``set_loss`` is called again.
        """
        loss = self.loss
        if clear:
            self.loss = None
        return loss

    @property
    def has_loss(self):
        r"""
        ``True`` when the stored loss is not ``None``.
        """
        return self.loss is not None