moe.py

import math
from torch import nn
import torch
import torch.nn.functional as F

from fmoe.layers import FMoELinear, _fmoe_full_forward


class FMoE(nn.Module):
    r"""A mixture-of-experts layer built on ``FMoELinear``.

    Each input row is dispatched to the expert selected by ``gate`` and
    transformed by that expert's linear projection.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=1):
        super(FMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        self.linear = FMoELinear(num_expert, in_feat, out_feat)
        # Expose the experts' weight tensor directly on the module.
        self.weight = self.linear.weight
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, inp, gate):
        # Dispatch each row to its gate-selected expert and apply that
        # expert's linear layer through the fused fmoe forward path.
        return _fmoe_full_forward(inp, gate, [self.linear], None,
                self.num_expert, self.world_size)


class BruteForceMoE(nn.Module):
    r"""A naive reference mixture-of-experts layer.

    Every expert's weight is stored in one parameter and each input row is
    multiplied by its selected expert's weight in a Python loop; slow, but
    easy to check ``FMoE`` against.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=1):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # One (out_feat x in_feat) weight matrix per expert, across all workers.
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialise every expert slot (num_expert * world_size in total),
        # reusing nn.Linear's default weight initialisation for each one.
        for i in range(self.weight.size(0)):
            linear = nn.Linear(in_features=self.in_feat,
                    out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Multiply each input row by the weight of its gate-selected expert.
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x
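

# A minimal usage sketch, not part of the original file: it assumes the `fmoe`
# package is importable (required by the imports above), builds a
# BruteForceMoE, assigns each input row a random expert index in place of a
# learned router, and runs one forward pass. FMoE would be called the same
# way, but it relies on the fused fmoe forward path.
if __name__ == "__main__":
    num_expert, in_feat, out_feat, batch_size = 4, 16, 8, 10
    moe = BruteForceMoE(num_expert=num_expert, in_feat=in_feat,
            out_feat=out_feat, world_size=1)
    inp = torch.randn(batch_size, in_feat)
    # One expert index per row; a gating network would normally produce this.
    gate = torch.randint(0, num_expert, (batch_size,))
    out = moe(inp, gate)
    print(out.shape)  # expected: torch.Size([10, 8])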