# moe.py
import math
from torch import nn
import torch


class BruteForceMoELinear(nn.Module):
    """Mixture-of-experts two-layer feed-forward block evaluated row by row.

    Each expert owns a ``d_model -> 4 * d_model`` weight matrix and a
    ``4 * d_model -> d_model`` weight matrix (no biases).  ``forward`` pushes
    every input row through the expert selected for it with a plain Python
    loop, then mixes each group of ``top_k`` consecutive rows using the
    supplied gate scores.

    NOTE: both weight tensors are allocated uninitialised; the caller is
    expected to fill them (e.g. copy them from another module) before calling
    ``forward``.

    Args:
        activation: Callable applied to the ``4 * d_model`` hidden vector
            between the two matrix products.
        num_expert: Number of experts per worker.
        d_model: Width of the input and output rows.
        world_size: Multiplier on ``num_expert``; the module stores
            ``num_expert * world_size`` experts in total.
        top_k: Number of consecutive rows that are combined into one output
            row in ``forward``.
    """

    def __init__(self, activation, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        total_experts = num_expert * world_size
        # Uninitialised storage, one weight matrix per expert.
        self.weight_htoh4 = nn.Parameter(
            torch.empty(total_experts, d_model * 4, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.empty(total_experts, d_model, d_model * 4)
        )
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        """Apply the selected expert to every row and mix rows by gate score.

        Args:
            inp: 2-D tensor with ``d_model`` columns; the number of rows must
                be a multiple of ``top_k`` (required by the ``view`` below).
            gate_idx: Per-row expert index; ``gate_idx[i]`` selects the expert
                used for ``inp[i]``.  Cast to ``long`` before indexing.
            gate_score: 3-D tensor used as the left operand of ``torch.bmm``
                against a ``(rows / top_k, top_k, d_model)`` tensor, so its
                first dimension must be ``rows / top_k`` and its last
                dimension ``top_k``.

        Returns:
            The ``bmm`` result reshaped to 2-D with ``d_model`` columns.
        """
        expert_ids = gate_idx.long()
        num_rows = inp.size(0)
        x = inp.new_zeros((num_rows, self.d_model))
        for row in range(num_rows):
            expert = expert_ids[row]
            hidden = self.activation(inp[row] @ self.weight_htoh4[expert].t())
            x[row] = hidden @ self.weight_h4toh[expert].t()
        # Group every top_k consecutive rows and take their weighted sum.
        mixed = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model))
        return mixed.reshape(-1, self.d_model)