import math from torch import nn import torch class BruteForceMoELinear(nn.Module): def __init__( self, activation, num_expert=32, d_model=1024, d_hidden=2048, world_size=1, top_k=2, ): super(BruteForceMoELinear, self).__init__() self.num_expert = num_expert self.d_model = d_model self.activation = activation self.weight_htoh4 = nn.Parameter( torch.Tensor(num_expert * world_size, d_hidden, d_model) ) self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden)) self.weight_h4toh = nn.Parameter( torch.Tensor(num_expert * world_size, d_model, d_hidden) ) self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model)) self.top_k = top_k def forward(self, inp, gate_idx, gate_score): inp = inp.repeat_interleave(repeats=self.top_k, dim=0) gate_long = gate_idx.long().view(-1) batch_size = inp.size(0) o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device) for i in range(self.weight_htoh4.shape[0]): idx = gate_long == i x = inp[idx] x = x @ self.weight_htoh4[i].t() x = x + self.bias_htoh4[i] x = self.activation(x) x = x @ self.weight_h4toh[i].t() x = x + self.bias_h4toh[i] o[idx] = x gate_score = gate_score.unsqueeze(1) x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape( -1, self.d_model ) return x class BruteForceMoE(nn.Module): def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2): super(BruteForceMoE, self).__init__() self.num_expert = num_expert self.d_model = d_model self.top_k = top_k if type(expert) is list: self.experts = [e(d_model) for e in expert] self.num_expert = num_expert = len(expert) else: self.experts = [expert(d_model) for _ in range(num_expert * world_size)] def forward(self, inp, gate_idx, gate_score): inp = inp.repeat_interleave(repeats=self.top_k, dim=0) gate_long = gate_idx.long().view(-1) batch_size = inp.size(0) x = inp.new_zeros((batch_size, self.d_model)) for i in range(batch_size): x[i] = self.experts[gate_long[i]](inp[i]) gate_score = gate_score.unsqueeze(1) x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape( -1, self.d_model ) return x class NaiveExpert(nn.Module): def __init__(self, d_model): super(NaiveExpert, self).__init__() self.linear = nn.Linear(d_model, d_model) def forward(self, x, fec=None): return self.linear(x) class LinearExpert(nn.Module): def __init__(self, d_model): super(LinearExpert, self).__init__() self.model = nn.Sequential( nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model), ) def forward(self, x, fec=None): return self.model(x)