moe.py 2.5 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
import math
from torch import nn
import torch


Rick Ho's avatar
Rick Ho committed
6
class BruteForceMoELinear(nn.Module):
    """Reference (brute-force) MoE layer with two bias-free linear maps per expert.

    There are ``num_expert * world_size`` experts. Expert ``i`` maps a row
    ``x`` to ``weight_h4toh[i] @ activation(weight_htoh4[i] @ x)``. The
    per-slot expert outputs are then mixed with ``gate_score`` through a
    batched matrix multiply.

    The weight tensors are allocated but NOT initialized here. Presumably
    the caller copies real values in (e.g. from the module under test);
    confirm before relying on their contents.

    Args:
        activation: callable applied to the hidden projection.
        num_expert: experts per worker.
        d_model: input/output feature size.
        d_hidden: hidden feature size of each expert.
        world_size: number of workers; the total expert count is
            ``num_expert * world_size``.
        top_k: number of slots per token that are mixed by ``gate_score``.
    """

    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        super(BruteForceMoELinear, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.activation = activation
        # One (d_hidden, d_model) and one (d_model, d_hidden) matrix per global
        # expert. Uninitialized on purpose; see the class docstring.
        self.weight_htoh4 = nn.Parameter(
            torch.empty(num_expert * world_size, d_hidden, d_model)
        )
        self.weight_h4toh = nn.Parameter(
            torch.empty(num_expert * world_size, d_model, d_hidden)
        )
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        """Run each row through its selected expert, then mix with gate scores.

        Args:
            inp: 2-D tensor of rows, ``(rows, d_model)``. The ``view`` below
                requires ``rows`` to be a multiple of ``top_k``. Presumably
                each token's ``top_k`` slots are consecutive; confirm against
                the caller.
            gate_idx: expert index for every row of ``inp``. Rows whose index
                matches no expert contribute zeros.
            gate_score: left operand of ``torch.bmm`` against the
                ``(-1, top_k, d_model)`` expert outputs, presumably
                ``(tokens, 1, top_k)``; confirm against the caller.

        Returns:
            The mixed outputs reshaped to ``(-1, d_model)``.
        """
        # Zero-fill rather than torch.empty: rows whose gate index matches no
        # expert (e.g. a negative index) would otherwise carry uninitialized
        # memory into the bmm below.
        o = inp.new_zeros((inp.size(0), self.d_model))
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_idx == i
            x = inp[idx] @ self.weight_htoh4[i].t()
            x = self.activation(x)
            o[idx] = x @ self.weight_h4toh[i].t()
        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model))
        return x.reshape(-1, self.d_model)
Rick Ho's avatar
Rick Ho committed
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

class BruteForceMoE(nn.Module):
    """Reference MoE that runs every row through its chosen expert one at a time.

    Args:
        expert: factory called as ``expert(d_model)`` once per global expert
            (``num_expert * world_size`` calls).
        num_expert: experts per worker.
        d_model: feature size of each expert's output row.
        world_size: number of workers.
        top_k: number of slots per token that are mixed by ``gate_score``.
    """

    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        experts = [expert(d_model) for _ in range(num_expert * world_size)]
        # A plain list hides module experts from .parameters(), state_dict()
        # and .to()/.cuda(). Register them via ModuleList. Non-module callables
        # cannot go into a ModuleList, so they stay in a plain list as before.
        if all(isinstance(e, nn.Module) for e in experts):
            experts = nn.ModuleList(experts)
        self.experts = experts

    def forward(self, inp, gate_idx, gate_score):
        """Apply ``experts[gate_idx[i]]`` to each row ``inp[i]``, then mix.

        Args:
            inp: tensor whose first dimension indexes rows. Each row is passed
                to its expert individually.
            gate_idx: expert index per row, converted to ``long``.
            gate_score: left operand of ``torch.bmm`` against the
                ``(-1, top_k, d_model)`` expert outputs, presumably
                ``(tokens, 1, top_k)``; confirm against the caller.

        Returns:
            The mixed outputs reshaped to ``(-1, d_model)``.
        """
        gate_long = gate_idx.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.d_model))
        for i in range(batch_size):
            x[i] = self.experts[int(gate_long[i])](inp[i])
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model))
        return x.reshape(-1, self.d_model)


class NaiveExpert(nn.Module):
    """Expert consisting of a single ``nn.Linear(d_model, d_model)``.

    Args:
        d_model: input and output feature size.
        device: where to place the layer. ``None`` (the default) keeps the
            original behaviour of calling ``.cuda()`` (current CUDA device).
            Pass e.g. ``"cpu"`` to build the expert without a GPU.
    """

    def __init__(self, d_model, device=None):
        super(NaiveExpert, self).__init__()
        linear = nn.Linear(d_model, d_model)
        self.linear = linear.cuda() if device is None else linear.to(device)

    def forward(self, x):
        """Return ``self.linear(x)``."""
        return self.linear(x)


class LinearExpert(nn.Module):
    """Two-layer MLP expert: ``Linear(d, 2d) -> ReLU -> Linear(2d, d)``.

    Args:
        d_model: input and output feature size; the hidden size is
            ``2 * d_model``.
        device: where to place the layers. ``None`` (the default) keeps the
            original behaviour of calling ``.cuda()`` (current CUDA device).
            Pass e.g. ``"cpu"`` to build the expert without a GPU.
    """

    def __init__(self, d_model, device=None):
        super(LinearExpert, self).__init__()
        model = nn.Sequential(
            nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model),
        )
        self.model = model.cuda() if device is None else model.to(device)

    def forward(self, x):
        """Return ``self.model(x)``."""
        return self.model(x)