moe.py 1.73 KB
Newer Older
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
1
2
3
4
import math
from torch import nn
import torch

Rick Ho's avatar
Rick Ho committed
5
from moe_function import moe
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
6
7
8


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer that delegates to the ``moe`` function.

    One ``(out_feat, in_feat)`` weight matrix is kept per expert; they are
    stacked into a single parameter of shape
    ``(num_expert, out_feat, in_feat)``.

    Args:
        num_expert: Number of expert weight matrices held by this layer.
        in_feat: Input feature size of every expert.
        out_feat: Output feature size of every expert.
        world_size: Stored and passed unchanged to ``moe`` on every forward
            call. NOTE(review): presumably the number of distributed
            workers -- confirm against ``moe_function``.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=None):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        # Allocated uninitialised; reset_parameters() fills every slot below.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert with the weight of a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                # A throwaway nn.Linear supplies the initial values; only its
                # weight is copied, its bias is discarded.
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Run ``moe`` on ``inp`` with ``gate`` cast to 32-bit integers.

        Args:
            inp: Input tensor handed to ``moe`` as-is.
            gate: Tensor converted with ``.int()`` before the call.
                NOTE(review): presumably one expert index per input row --
                confirm against ``moe_function``.

        Returns:
            Whatever ``moe(inp, gate.int(), self.weight, self.world_size)``
            returns.
        """
        return moe(inp, gate.int(), self.weight, self.world_size)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
27
28


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
29
class MOELayer_raw(nn.Module):
    """Pure-PyTorch mixture-of-experts linear layer using a per-row loop.

    The layer owns ``num_expert * world_size`` weight matrices of shape
    ``(out_feat, in_feat)``, stacked into one parameter. ``forward`` applies
    to each input row the matrix selected by the matching ``gate`` entry.

    Args:
        num_expert: Factor of the expert count (multiplied by
            ``world_size``).
        in_feat: Input feature size of every expert.
        out_feat: Output feature size of every expert.
        world_size: Second factor of the expert count. The default ``0``
            allocates no experts at all, so ``forward`` cannot index any
            weight; pass a value >= 1 for a usable layer.
            NOTE(review): presumably the number of distributed workers
            being emulated -- confirm with the caller.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=0):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every slot below.
        self.weight = nn.Parameter(
            torch.empty(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert slot with the weight of a fresh ``nn.Linear``."""
        with torch.no_grad():
            # Walk all allocated slots (num_expert * world_size), not just the
            # first num_expert, so no slot keeps uninitialised memory and an
            # empty parameter is simply skipped.
            for i in range(self.weight.size(0)):
                # Only the weight of the throwaway layer is kept; its bias is
                # discarded.
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply to each row of ``inp`` the expert chosen by ``gate``.

        Args:
            inp: Input tensor; its first dimension is the batch size.
                Assumes each row has ``in_feat`` elements -- TODO confirm
                with callers.
            gate: Tensor of expert indices, one per row, cast with
                ``.long()`` before indexing ``self.weight``.

        Returns:
            Tensor of shape ``(batch_size, out_feat)`` created with
            ``inp.new_zeros``, where row ``i`` is
            ``inp[i] @ self.weight[gate[i]].t()``.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Deliberately simple row-by-row evaluation.
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x