moe.py 2.79 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
import math
from torch import nn
import torch


Rick Ho's avatar
Rick Ho committed
6
class BruteForceMoELinear(nn.Module):
    """Reference mixture-of-experts feed-forward layer, evaluated expert by expert.

    Every expert is a two-layer network: ``d_model -> d_hidden``, the
    ``activation`` callable, then ``d_hidden -> d_model``.  The weights and
    biases of all ``num_expert * world_size`` experts are stacked along the
    first dimension of four parameters.

    The parameters are allocated with ``torch.Tensor(*sizes)`` and are NOT
    initialised here; the caller has to fill them before calling ``forward``.
    """

    def __init__(
        self,
        activation,
        num_expert=32,
        d_model=1024,
        d_hidden=2048,
        world_size=1,
        top_k=2,
    ):
        """Allocate the stacked expert parameters.

        Args:
            activation: Callable applied between the two linear maps.
            num_expert: Number of experts per worker.
            d_model: Input and output feature size of each expert.
            d_hidden: Hidden feature size of each expert.
            world_size: Multiplier on ``num_expert`` giving the total number
                of experts stored in this module.
            top_k: Number of consecutive rows of the per-row expert output
                that ``forward`` mixes into one output row.
        """
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.activation = activation
        self.top_k = top_k

        total_expert = num_expert * world_size
        # Registration order is kept (weight, bias, weight, bias) so that
        # ``parameters()`` and ``state_dict()`` enumerate them as before.
        self.weight_htoh4 = nn.Parameter(
            torch.Tensor(total_expert, d_hidden, d_model)
        )
        self.bias_htoh4 = nn.Parameter(torch.Tensor(total_expert, d_hidden))
        self.weight_h4toh = nn.Parameter(
            torch.Tensor(total_expert, d_model, d_hidden)
        )
        self.bias_h4toh = nn.Parameter(torch.Tensor(total_expert, d_model))

    def forward(self, inp, gate_idx, gate_score):
        """Run each row through its selected expert and mix the results.

        Args:
            inp: Tensor whose rows are the expert inputs; the last dimension
                must be ``d_model``.
            gate_idx: Expert id for each row of ``inp`` (compared with ``==``
                against every expert index to build a row mask).
            gate_score: 3-D tensor passed to ``torch.bmm`` as the left
                operand; its last dimension must be ``top_k`` and its first
                dimension ``inp.size(0) // top_k``.

        Returns:
            Tensor with ``d_model`` columns: the ``bmm`` of ``gate_score``
            with the expert outputs grouped ``top_k`` rows at a time.
        """
        # Zero-filled rather than torch.empty: a row whose gate index matches
        # no expert would otherwise carry uninitialised memory into the bmm.
        o = inp.new_zeros((inp.size(0), self.d_model))
        for i in range(self.weight_htoh4.shape[0]):
            idx = gate_idx == i
            x = inp[idx]
            x = x @ self.weight_htoh4[i].t()
            x = x + self.bias_htoh4[i]
            x = self.activation(x)
            x = x @ self.weight_h4toh[i].t()
            x = x + self.bias_h4toh[i]
            o[idx] = x
        # Group every top_k consecutive rows and take their weighted sum.
        mixed = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model))
        return mixed.reshape(-1, self.d_model)
Rick Ho's avatar
Rick Ho committed
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

class BruteForceMoE(nn.Module):
    """Reference mixture-of-experts layer that applies experts one row at a time.

    ``expert`` is a factory called as ``expert(d_model)`` once for each of the
    ``num_expert * world_size`` experts.
    """

    def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
        """Build the experts.

        Args:
            expert: Callable taking ``d_model`` and returning one expert
                (itself a callable applied to a single row of the input).
            num_expert: Number of experts per worker.
            d_model: Feature size handed to the factory and used for the
                output rows.
            world_size: Multiplier on ``num_expert`` giving the total number
                of experts built.
            top_k: Number of consecutive rows of the per-row expert output
                that ``forward`` mixes into one output row.
        """
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        experts = [expert(d_model) for _ in range(num_expert * world_size)]
        # A plain list hides the experts from nn.Module bookkeeping
        # (parameters(), state_dict(), .to(), zero_grad()).  Register them
        # when possible; keep a plain list for factories that return
        # non-Module callables so those callers keep working.
        if all(isinstance(e, nn.Module) for e in experts):
            experts = nn.ModuleList(experts)
        self.experts = experts

    def forward(self, inp, gate_idx, gate_score):
        """Apply the selected expert to every row, then mix with the gate scores.

        Args:
            inp: Tensor iterated along its first dimension; each row is fed
                to one expert.
            gate_idx: Expert id for each row of ``inp``.
            gate_score: 3-D tensor passed to ``torch.bmm`` as the left
                operand; its last dimension must be ``top_k`` and its first
                dimension ``inp.size(0) // top_k``.

        Returns:
            Tensor with ``d_model`` columns: the ``bmm`` of ``gate_score``
            with the expert outputs grouped ``top_k`` rows at a time.
        """
        # Convert once to Python ints instead of indexing the expert
        # container with a 0-dim tensor on every iteration.
        expert_ids = gate_idx.long().flatten().tolist()
        x = inp.new_zeros((inp.size(0), self.d_model))
        for i, row in enumerate(inp):
            x[i] = self.experts[expert_ids[i]](row)
        # Group every top_k consecutive rows and take their weighted sum.
        mixed = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model))
        return mixed.reshape(-1, self.d_model)


class NaiveExpert(nn.Module):
    """Expert made of a single ``d_model -> d_model`` linear layer."""

    def __init__(self, d_model):
        """Create the linear layer and move it to the CUDA device."""
        super().__init__()
        layer = nn.Linear(d_model, d_model)
        self.linear = layer.cuda()

    def forward(self, x):
        """Return the linear layer applied to ``x``."""
        projected = self.linear(x)
        return projected


class LinearExpert(nn.Module):
    """Expert made of two linear layers with a ReLU between them.

    The hidden width is twice ``d_model``.
    """

    def __init__(self, d_model):
        """Create the ``Linear -> ReLU -> Linear`` stack and move it to CUDA."""
        super().__init__()
        hidden = d_model * 2
        layers = [
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
        ]
        self.model = nn.Sequential(*layers).cuda()

    def forward(self, x):
        """Return the layer stack applied to ``x``."""
        result = self.model(x)
        return result