# moe.py
import math
from torch import nn
import torch


class BruteForceMoELinear(nn.Module):
    """Loop-based reference mixture of two-layer MLP experts.

    Expert ``e`` maps a row ``x`` to
    ``activation(x @ weight_htoh4[e].T + bias_htoh4[e]) @ weight_h4toh[e].T
    + bias_h4toh[e]``. The per-row expert outputs are then combined
    ``top_k`` rows at a time using ``gate_score`` (see ``forward``).

    Args:
        activation: Callable applied to the hidden projection.
        num_expert: Number of experts per worker.
        d_model: Input/output feature size of every expert.
        d_hidden: Hidden feature size of every expert.
        world_size: Number of workers; ``num_expert * world_size`` experts
            are allocated in total.
        top_k: Number of consecutive rows combined into one output row.

    NOTE(review): the parameters are created with ``torch.Tensor(...)`` and
    are therefore uninitialised; presumably callers copy weights in before
    calling ``forward`` — confirm against the code that uses this class.
    """

    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        # One (d_hidden, d_model) / (d_model, d_hidden) weight pair per expert.
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_hidden, d_model)
        )
        self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden))
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(num_expert * world_size, d_model, d_hidden)
        )
        self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model))
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        """Route each row of ``inp`` to its expert and mix the results.

        Args:
            inp: Tensor whose row ``r`` is processed by expert ``gate_idx[r]``.
            gate_idx: Expert index for each row of ``inp``.
            gate_score: Mixing weights; ``torch.bmm`` multiplies it with a
                ``(-1, top_k, d_model)`` view of the expert outputs, so it is
                presumably shaped ``(rows // top_k, 1, top_k)`` — confirm with
                callers.

        Returns:
            The mixed outputs reshaped to ``(-1, d_model)``.
        """
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        # Zero-filled (rather than torch.empty) so a row whose index matches
        # no expert yields zeros instead of uninitialised memory.
        o = inp.new_zeros((batch_size, self.d_model))
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_long == i
            x = inp[idx]
            x = x @ self.weight_htoh4[i].t()
            x = x + self.bias_htoh4[i]
            x = self.activation(x)
            x = x @ self.weight_h4toh[i].t()
            x = x + self.bias_h4toh[i]
            o[idx] = x
        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x

class BruteForceMoE(nn.Module):
    """Reference mixture of arbitrary expert modules, evaluated row by row.

    Args:
        expert: Factory called as ``expert(d_model)`` to build each expert
            module.
        num_expert: Number of experts per worker.
        d_model: Output feature size written for every row.
        world_size: Number of workers; ``num_expert * world_size`` experts
            are built in total.
        top_k: Number of consecutive rows combined into one output row.
    """

    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        # nn.ModuleList (not a plain list) registers the experts as
        # submodules, so their parameters appear in .parameters() and
        # state_dict() and move with .to()/.cuda().
        self.experts = nn.ModuleList(
            expert(d_model) for _ in range(num_expert * world_size)
        )

    def forward(self, inp, gate_idx, gate_score):
        """Apply expert ``gate_idx[r]`` to row ``r`` of ``inp``, then mix.

        Args:
            inp: Tensor whose row ``r`` is fed to expert ``gate_idx[r]``.
            gate_idx: Expert index for each row of ``inp``.
            gate_score: Mixing weights; ``torch.bmm`` multiplies it with a
                ``(-1, top_k, d_model)`` view of the expert outputs, so it is
                presumably shaped ``(rows // top_k, 1, top_k)`` — confirm with
                callers.

        Returns:
            The mixed outputs reshaped to ``(-1, d_model)``.
        """
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            # int() gives ModuleList a plain Python index.
            x[i] = self.experts[int(gate_long[i])](inp[i])
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x


class NaiveExpert(nn.Module):
    """Single ``d_model -> d_model`` linear layer used as a test expert.

    Args:
        d_model: Input and output feature size.
        device: Device the layer is moved to. Defaults to ``"cuda"`` (the
            current CUDA device), as before; pass e.g. ``"cpu"`` to build
            it without a GPU.
    """

    def __init__(self, d_model, device="cuda"):
        super(NaiveExpert, self).__init__()
        self.linear = nn.Linear(d_model, d_model).to(device)

    def forward(self, x):
        """Return ``self.linear(x)``."""
        return self.linear(x)


class LinearExpert(nn.Module):
    """Two-layer ReLU MLP (``d_model -> 2*d_model -> d_model``) test expert.

    Args:
        d_model: Input and output feature size; the hidden layer is twice
            as wide.
        device: Device the MLP is moved to. Defaults to ``"cuda"`` (the
            current CUDA device), as before; pass e.g. ``"cpu"`` to build
            it without a GPU.
    """

    def __init__(self, d_model, device="cuda"):
        super(LinearExpert, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model),
        ).to(device)

    def forward(self, x):
        """Return ``self.model(x)``."""
        return self.model(x)