import math
from torch import nn
import torch
import torch.nn.functional as F

from .moe_function import moe


class FMoE(nn.Module):
    """Mixture-of-experts linear layer backed by the fused ``moe`` kernel.

    Holds one ``(out_feat x in_feat)`` weight matrix per local expert and
    dispatches each input row to the expert chosen by ``gate``.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=None):
        super(FMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        # One weight matrix per local expert, stacked along dim 0.
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Borrow nn.Linear's default initialization so each expert starts
        # out exactly like an independently constructed linear layer.
        for expert_idx in range(self.num_expert):
            ref = nn.Linear(in_features=self.in_feat,
                    out_features=self.out_feat)
            self.weight.data[expert_idx] = ref.weight.data

    def forward(self, inp, gate):
        # The custom kernel expects int32 expert indices.
        return moe(inp, gate.int(), self.weight, self.world_size)


class FFFN(nn.Module):
    """Position-wise FFN whose two projections are mixture-of-experts layers.

    Each token is routed to its ``top_k`` experts; the expert outputs are
    blended with a softmax over the selected gate logits. The output bias is
    returned separately instead of being added in.
    """

    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096, 
            world_size=None, activation=torch.nn.functional.gelu,
            top_k=2, pre_lnorm=False):
        super(FFFN, self).__init__()
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm

        # Expert-parallel projections: d_model -> d_hidden -> d_model.
        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
                world_size=world_size)
        self.h4toh = FMoE(num_expert, d_hidden, d_model, 
                world_size=world_size)
        # Router producing one logit per expert for every token.
        self.gate = nn.Linear(d_model, num_expert)
        self.layer_norm = nn.LayerNorm(d_model)
        # Output bias, handed back to the caller rather than added here.
        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
                dtype=torch.float32)) 

    def forward(self, inp):
        """Route ``inp`` (B x L x d_model) through the top-k experts.

        Returns ``(output, bias)`` where ``output`` has the same shape as
        ``inp`` and ``bias`` is the module's shared output bias parameter.
        """
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        # Select the top-k experts per token from the router logits.
        logits = self.gate(inp)
        top_val, top_idx = torch.topk(
                logits, k=self.top_k, dim=-1, largest=True, sorted=False)
        top_val = top_val.view(-1, self.top_k)

        # Mixing weights over the selected experts: (BxL) x 1 x top_k.
        mix = F.softmax(top_val, dim=-1).unsqueeze(1)
        flat_idx = top_idx.view(-1)  # (BxLxtop_k)

        # Duplicate every token once per selected expert:
        # (BxLxtop_k) x d_model.
        expanded = inp.view(-1, self.d_model).repeat_interleave(
                repeats=self.top_k, dim=0)
        hidden = self.htoh4(expanded, flat_idx)
        hidden = self.activation(hidden)
        hidden = self.h4toh(hidden, flat_idx)

        # Blend each token's top_k expert outputs with the mixing weights.
        per_expert = hidden.view(-1, self.top_k, self.d_model)
        blended = torch.bmm(mix, per_expert)  # (BxL) x 1 x d_model
        blended = blended.view(
                residual.size(0), residual.size(1), self.d_model)
        output = blended + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias


class BruteForceMoE(nn.Module):
    """Reference (unoptimized) MoE layer for validating the fused kernel.

    Stores weight matrices for ``num_expert * world_size`` experts and
    applies the expert selected by ``gate[i]`` to input row ``i`` with a
    plain Python loop.

    NOTE(review): ``world_size`` defaults to 0, which allocates an *empty*
    weight tensor; callers are expected to pass the actual world size.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, 
            world_size=0):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # One weight matrix per expert across ALL ranks.
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Bug fix: initialize ALL num_expert * world_size expert slots.
        # The original iterated range(self.num_expert) only, leaving every
        # expert beyond the first num_expert holding the uninitialized
        # memory produced by torch.Tensor(...).
        for i in range(self.weight.size(0)):
            # Borrow nn.Linear's default weight initialization.
            linear = nn.Linear(in_features=self.in_feat, 
                    out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        """Apply expert ``gate[i]`` to ``inp[i]``; returns (batch, out_feat)."""
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            # y = W_e @ x  expressed as  x @ W_e^T  for row vectors.
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x