moe.py 1.09 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
import math
from torch import nn
import torch
Rick Ho's avatar
Rick Ho committed
4
import torch.nn.functional as F
Rick Ho's avatar
Rick Ho committed
5
6


Rick Ho's avatar
Rick Ho committed
7
class BruteForceMoELinear(nn.Module):
    """Mixture-of-experts linear layer computed with a plain per-expert loop.

    Holds one ``(out_feat, in_feat)`` weight matrix per expert, stacked into a
    single parameter of shape ``(num_expert * world_size, out_feat, in_feat)``.
    Each input row is multiplied by the weight of the expert that ``gate``
    selects for it; no bias is applied.

    Args:
        num_expert: Number of experts per worker.
        in_feat: Size of each input row.
        out_feat: Size of each output row.
        world_size: Number of workers whose experts are held here; the layer
            owns ``num_expert * world_size`` experts in total. Defaults to 1
            (a default of 0 produced an empty weight tensor and made
            ``reset_parameters`` raise ``IndexError``).
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
            world_size=1):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert's weight the way ``nn.Linear`` would.

        All ``num_expert * world_size`` experts are initialised, not just the
        first ``num_expert``; otherwise the remaining experts would keep the
        uninitialised memory returned by ``torch.Tensor(...)``.
        """
        with torch.no_grad():
            for i in range(self.weight.size(0)):
                # Borrow nn.Linear's default weight initialisation for each
                # expert; the temporary layer's bias is discarded.
                linear = nn.Linear(in_features=self.in_feat,
                        out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply to each input row the expert selected for it.

        Args:
            inp: Input of shape ``(batch, in_feat)``.
            gate: Expert index for each row of ``inp`` -- ``batch`` elements
                in any shape (e.g. ``(batch,)`` or ``(batch, 1)``), of any
                dtype convertible with ``.long()``. Values are expected to lie
                in ``[0, num_expert * world_size)``.

        Returns:
            Tensor of shape ``(batch, out_feat)`` with the dtype and device of
            ``inp``, where row ``k`` is ``weight[gate[k]] @ inp[k]``. Rows
            whose gate is outside the valid range are zero.
        """
        gate_long = gate.long().reshape(-1)
        batch_size = inp.size(0)
        # zeros rather than empty: rows that no expert claims must not
        # expose uninitialised memory to the caller.
        o = torch.zeros(batch_size, self.out_feat, dtype=inp.dtype,
                device=inp.device)
        # Loop over every expert held by this layer, including those
        # belonging to other workers when world_size > 1.
        for i in range(self.weight.size(0)):
            idx = gate_long == i
            o[idx] = inp[idx] @ self.weight[i].t()
        return o