# moe.py
import math
from torch import nn
import torch


Sengxian's avatar
Sengxian committed
6
class BruteForceMoELinear(nn.Module):
    """Two-layer mixture-of-experts feed-forward block evaluated row by row.

    Each input row is pushed through the two weight matrices of the expert
    selected for it, and groups of ``top_k`` consecutive rows are then
    combined with ``gate_score`` through a batched matrix product.

    Args:
        activation: Callable applied between the two linear maps.
        num_expert: Number of experts per worker.
        d_model: Size of each input and output row.
        d_hidden: Size of the intermediate representation.
        world_size: Multiplier for the number of experts; the parameters
            hold ``num_expert * world_size`` experts in total.
        top_k: Number of consecutive rows combined into one output row.
    """

    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        self.top_k = top_k

        total_expert = num_expert * world_size
        # torch.empty (like the legacy torch.Tensor(shape) constructor) hands
        # back uninitialised storage, so the weights are filled explicitly
        # below instead of being left holding arbitrary values.
        self.weight_htoh4 = nn.Parameter(
            torch.empty(total_expert, d_hidden, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.empty(total_expert, d_model, d_hidden)
        )
        self._init_weight(self.weight_htoh4)
        self._init_weight(self.weight_h4toh)

    @staticmethod
    def _init_weight(weight):
        """Fill ``weight`` uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.

        ``fan_in`` is the last dimension, i.e. the input width of each
        per-expert matrix, so every expert is scaled independently of the
        number of experts.
        """
        fan_in = weight.size(-1)
        # Guard against a zero-width layer to avoid dividing by zero.
        bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0
        nn.init.uniform_(weight, -bound, bound)

    def forward(self, inp, gate_idx, gate_score):
        """Apply the selected expert to every row and mix the results.

        Args:
            inp: Tensor whose first dimension indexes rows; each row is
                multiplied against a ``(d_hidden, d_model)`` expert matrix.
            gate_idx: Per-row expert index (converted to ``long``).
            gate_score: 3-D tensor batch-multiplied against the expert
                outputs viewed as ``(-1, top_k, d_model)``.

        Returns:
            Tensor reshaped to ``(-1, d_model)`` holding the mixed outputs.
        """
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            expert = gate_long[i]
            hidden = self.activation(inp[i] @ self.weight_htoh4[expert].t())
            x[i] = hidden @ self.weight_h4toh[expert].t()
        # Group every top_k consecutive rows and weight them by gate_score.
        grouped = x.view(-1, self.top_k, self.d_model)
        return torch.bmm(gate_score, grouped).reshape(-1, self.d_model)
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


class BruteForceMoE(nn.Module):
    """Mixture-of-experts layer that calls one expert module per input row.

    Args:
        expert: Callable invoked as ``expert(d_model)`` to build each expert
            module.
        num_expert: Number of experts per worker.
        d_model: Size of each output row.
        world_size: Multiplier for the number of experts; a total of
            ``num_expert * world_size`` experts is created.
        top_k: Number of consecutive rows combined into one output row.
    """

    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        # A ModuleList (rather than a plain list) registers every expert as a
        # sub-module, so parameters(), state_dict(), .to() and train()/eval()
        # reach them. Indexing, iteration and len() behave like a list.
        self.experts = nn.ModuleList(
            expert(d_model) for _ in range(num_expert * world_size)
        )

    def forward(self, inp, gate_idx, gate_score):
        """Run the selected expert on every row and mix the results.

        Args:
            inp: Tensor whose first dimension indexes rows.
            gate_idx: Per-row expert index (converted to ``long``).
            gate_score: 3-D tensor batch-multiplied against the expert
                outputs viewed as ``(-1, top_k, d_model)``.

        Returns:
            Tensor reshaped to ``(-1, d_model)`` holding the mixed outputs.
        """
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            # int() turns the one-element index tensor into a plain integer.
            x[i] = self.experts[int(gate_long[i])](inp[i])
        # Group every top_k consecutive rows and weight them by gate_score.
        grouped = x.view(-1, self.top_k, self.d_model)
        return torch.bmm(gate_score, grouped).reshape(-1, self.d_model)


class NaiveExpert(nn.Module):
    """Expert consisting of a single ``d_model -> d_model`` linear layer.

    Args:
        d_model: Input and output feature size of the linear layer.
    """

    def __init__(self, d_model):
        super(NaiveExpert, self).__init__()
        self.linear = nn.Linear(d_model, d_model)
        # Place the layer on the GPU when one exists; an unconditional
        # .cuda() would raise on CPU-only hosts and make the class unusable.
        if torch.cuda.is_available():
            self.linear.cuda()

    def forward(self, x):
        """Return the linear layer applied to ``x``."""
        return self.linear(x)


class LinearExpert(nn.Module):
    """Expert made of ``Linear -> ReLU -> Linear`` with a 2x wide hidden layer.

    Args:
        d_model: Input and output feature size; the hidden layer has
            ``d_model * 2`` features.
    """

    def __init__(self, d_model):
        super(LinearExpert, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model),
        )
        # Place the stack on the GPU when one exists; an unconditional
        # .cuda() would raise on CPU-only hosts and make the class unusable.
        if torch.cuda.is_available():
            self.model.cuda()

    def forward(self, x):
        """Return the feed-forward stack applied to ``x``."""
        return self.model(x)