moe.py 2.97 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
import math
from torch import nn
import torch


Rick Ho's avatar
Rick Ho committed
6
class BruteForceMoELinear(nn.Module):
    """Reference mixture-of-experts two-layer feed-forward, computed expert by expert.

    Expert ``i`` maps a row ``x`` to
    ``activation(x @ weight_htoh4[i].T + bias_htoh4[i]) @ weight_h4toh[i].T + bias_h4toh[i]``.
    Each token is routed to ``top_k`` experts, and their outputs are mixed with
    the per-token gate scores.

    Args:
        activation: callable applied to the hidden projection.
        num_expert: number of experts per worker.
        d_model: input/output feature size.
        d_hidden: hidden feature size of every expert.
        world_size: number of workers; ``num_expert * world_size`` experts are held.
        top_k: number of experts each token is routed to.

    Note:
        Parameters are allocated with ``torch.Tensor(...)`` and are not
        initialized here. Presumably the caller copies weights in before use;
        TODO confirm against callers.
    """

    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        total_expert = num_expert * world_size
        # Stacked per-expert weights; expert i uses slice [i] of each tensor.
        self.weight_htoh4 = nn.Parameter(torch.Tensor(total_expert, d_hidden, d_model))
        self.bias_htoh4 = nn.Parameter(torch.Tensor(total_expert, d_hidden))
        self.weight_h4toh = nn.Parameter(torch.Tensor(total_expert, d_model, d_hidden))
        self.bias_h4toh = nn.Parameter(torch.Tensor(total_expert, d_model))
        self.top_k = top_k

    def forward(self, inp, gate_idx, gate_score):
        """Route each token through its ``top_k`` experts and mix the results.

        Args:
            inp: token features, one row per token (N rows, ``d_model`` columns).
            gate_idx: expert index for every (token, slot). After flattening it
                must hold ``N * top_k`` entries, ordered token-major.
            gate_score: mixing weights of shape ``(N, top_k)``, as required by
                the ``bmm`` below.

        Returns:
            Tensor of shape ``(N, d_model)``.
        """
        # One copy of each token per routing slot, aligned with flattened gate_idx.
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        gate_long = gate_idx.long().view(-1)
        batch_size = inp.size(0)
        # zeros (not empty): rows whose index matches no expert in range
        # (e.g. a negative index) must contribute 0, not uninitialized memory.
        o = torch.zeros(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_long == i
            x = inp[idx]
            x = x @ self.weight_htoh4[i].t() + self.bias_htoh4[i]
            x = self.activation(x)
            x = x @ self.weight_h4toh[i].t() + self.bias_h4toh[i]
            o[idx] = x
        # (N, 1, top_k) @ (N, top_k, d_model) -> (N, 1, d_model): score-weighted sum.
        gate_score = gate_score.unsqueeze(1)
        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x
Rick Ho's avatar
Rick Ho committed
50

51
52
53
54
55
56
57
58
59
60

class BruteForceMoE(nn.Module):
    """Reference mixture-of-experts that runs every routed token through its expert one row at a time.

    Args:
        expert: factory called as ``expert(d_model)`` to build each expert module.
        num_expert: number of experts per worker.
        d_model: input/output feature size.
        world_size: number of workers; ``num_expert * world_size`` experts are built.
        top_k: number of experts each token is routed to.
    """

    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        # ModuleList (not a plain list) so the experts are registered submodules:
        # their parameters appear in parameters()/state_dict() and follow .to()/.cuda().
        self.experts = nn.ModuleList(
            expert(d_model) for _ in range(num_expert * world_size)
        )

    def forward(self, inp, gate_idx, gate_score):
        """Route each token through its ``top_k`` experts and mix the results.

        Args:
            inp: token features, one row per token (N rows).
            gate_idx: expert index for every (token, slot). After flattening it
                must hold ``N * top_k`` entries, ordered token-major.
            gate_score: mixing weights of shape ``(N, top_k)``, as required by
                the ``bmm`` below.

        Returns:
            Tensor of shape ``(N, d_model)``.
        """
        # One copy of each token per routing slot, aligned with flattened gate_idx.
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        gate_long = gate_idx.long().view(-1)
        batch_size = inp.size(0)
        num_total = len(self.experts)
        # Single host transfer instead of per-row 0-d tensor indexing.
        expert_ids = gate_long.tolist()
        x = inp.new_zeros((batch_size, self.d_model))
        for row in range(batch_size):
            expert_id = expert_ids[row]
            # Out-of-range ids (e.g. negative) leave a zero row. Python's negative
            # indexing would otherwise silently pick the last expert.
            if 0 <= expert_id < num_total:
                x[row] = self.experts[expert_id](inp[row])
        # (N, 1, top_k) @ (N, top_k, d_model) -> (N, 1, d_model): score-weighted sum.
        gate_score = gate_score.unsqueeze(1)
        x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
            -1, self.d_model
        )
        return x


class NaiveExpert(nn.Module):
    """Expert made of a single ``nn.Linear(d_model, d_model)`` layer.

    Args:
        d_model: input and output feature size.
        device: device the layer is moved to. Defaults to ``"cuda"`` (the
            current CUDA device); pass ``"cpu"`` to build it without a GPU.
    """

    def __init__(self, d_model, device="cuda"):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model).to(device)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.linear(x)


class LinearExpert(nn.Module):
    """Two-layer MLP expert: ``d_model -> 2*d_model -> ReLU -> d_model``.

    Args:
        d_model: input and output feature size.
        device: device the MLP is moved to. Defaults to ``"cuda"`` (the
            current CUDA device); pass ``"cpu"`` to build it without a GPU.
    """

    def __init__(self, d_model, device="cuda"):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model),
        ).to(device)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.model(x)