r"""
Base gate with standard interface
"""
import torch.nn as nn


class BaseGate(nn.Module):
    """Common base class for gate modules.

    It records the expert counts and provides a single slot for a loss
    value that can be stored with :meth:`set_loss` and read back with
    :meth:`get_loss`. ``forward`` is intentionally left unimplemented;
    subclasses are expected to override it.

    Attributes set by ``__init__``:
        world_size: the ``world_size`` argument, stored unchanged.
        num_expert: the ``num_expert`` argument, stored unchanged.
        tot_expert: ``world_size * num_expert``.
        loss: the currently stored loss, or ``None`` when nothing is stored.
    """

    def __init__(self, num_expert, world_size):
        """Store the expert configuration and start with no loss.

        Args:
            num_expert: expert count; presumably the number of experts per
                worker -- confirm against callers.
            world_size: multiplier applied to ``num_expert`` to obtain
                ``tot_expert``; presumably the number of workers -- confirm
                against callers.
        """
        super().__init__()
        self.world_size = world_size
        self.num_expert = num_expert
        self.tot_expert = world_size * num_expert
        self.loss = None

    def forward(self, x):
        """Always raise; the base gate has no forward computation.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError('Base gate cannot be directly used for fwd')

    def set_loss(self, loss):
        """Store ``loss`` on the gate, replacing any previously stored value."""
        self.loss = loss

    def get_loss(self, clear=True):
        """Return the stored loss.

        Args:
            clear: when true (the default), reset the stored loss to ``None``
                after reading it, so a second call returns ``None`` unless
                ``set_loss`` is called again in between.

        Returns:
            Whatever was last passed to ``set_loss``, or ``None`` if nothing
            is currently stored.
        """
        loss = self.loss
        if clear:
            self.loss = None
        return loss