# moe.py: mixture-of-experts linear layer backed by the `moe_cuda` extension,
# a pure-PyTorch reference layer, and a small script that prints both outputs.
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Seed the CPU and CUDA generators so the weights and inputs created in this
# module are reproducible from run to run.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

class MOEFunction(Function):
    """Autograd wrapper around the ``moe_cuda`` forward/backward kernels."""

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Call ``moe_cuda.forward`` and save its inputs for ``backward``.

        Args:
            ctx: autograd context used to stash tensors for the backward pass.
            inp: input tensor, handed to the kernel unchanged.
            gate: gating tensor, handed to the kernel unchanged.
            weight: 3-D tensor whose two trailing dimensions are read as
                ``(out_feat, in_feat)``.

        Returns:
            The first element of the value returned by ``moe_cuda.forward``.
        """
        out_feat, in_feat = weight.size()[1:]
        # Store every matrix transposed in memory while keeping the
        # (out_feat, in_feat) shape, i.e. the same values in column-major
        # order.
        weight_column_major = (
            weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        )
        output = moe_cuda.forward(inp, gate, weight_column_major)
        ctx.save_for_backward(inp, gate, weight_column_major)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Call ``moe_cuda.backward`` with the tensors saved by ``forward``.

        Args:
            ctx: autograd context holding ``(inp, gate, weight_column_major)``.
            grad_out: gradient with respect to the value returned by
                ``forward``; made contiguous before it reaches the kernel.

        Returns:
            ``(grad_inp, None, grad_weight)``: one entry per ``forward``
            argument, with no gradient for ``gate``.
        """
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
        # NOTE(review): this repeats the transpose/contiguous/view sequence
        # used in forward, presumably to bring the kernel's column-major
        # gradient back to row-major. Confirm against moe_cuda.backward's
        # output layout that this is the intended inverse when
        # out_feat != in_feat.
        grad_weight_row_major = (
            grad_weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        )
        return grad_inp, None, grad_weight_row_major


class MOELayer(nn.Module):
    """Stack of per-expert weight matrices applied through ``MOEFunction``.

    Args:
        num_expert: number of weight matrices held by the layer.
        in_feat: input feature size of every matrix.
        out_feat: output feature size of every matrix.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() overwrites every slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Give each expert the default ``nn.Linear`` weight initialisation."""
        # A throwaway nn.Linear is built per expert so the values (and the
        # amount of RNG state consumed) match a stand-alone linear layer.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Return ``MOEFunction.apply(inp, gate, self.weight)``."""
        return MOEFunction.apply(inp, gate, self.weight)


# Pure-PyTorch counterpart of MOELayer; test() prints both outputs together.
class MOELayer_einsum(nn.Module):
    """Pure-PyTorch layer that applies one weight matrix per input row.

    Row ``i`` of the output is ``weight[gate[i]] @ inp[i]``.

    Args:
        num_expert: number of weight matrices held by the layer.
        in_feat: input feature size of every matrix.
        out_feat: output feature size of every matrix.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() overwrites every slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Give each expert the default ``nn.Linear`` weight initialisation."""
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Multiply every input row by the weight matrix its gate selects.

        Args:
            inp: tensor whose first dimension is the batch and whose rows
                have ``in_feat`` elements.
            gate: integer tensor with one expert index per input row.

        Returns:
            Tensor of shape ``(batch, out_feat)`` created with
            ``inp.new_zeros``.

        Raises:
            IndexError: if a gate value is outside the range of valid
                indices into ``weight``.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Group rows by expert: one matmul per expert that actually occurs
        # instead of one Python-level mat-vec product per sample.
        for expert in torch.unique(gate_long).tolist():
            rows = torch.nonzero(gate_long == expert, as_tuple=True)[0]
            x[rows] = inp[rows] @ self.weight[expert].t()
        return x

# Manual comparison script; needs a CUDA device and the moe_cuda extension.
def test():
    """Print the output of ``MOELayer`` next to that of ``MOELayer_einsum``.

    Both layers are moved to CUDA, share the same weights, and receive the
    same random input and gate, so the two printed tensors can be compared
    by eye.
    """
    batch_size = 4
    num_expert = 4
    in_feat = 2
    out_feat = 3

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
    # Give both layers identical weights so their outputs are comparable.
    with torch.no_grad():
        moe_einsum.weight.copy_(moe.weight)

    # Sampled on the CPU and then moved, so the CPU seed decides the values.
    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(
        low=0, high=num_expert, size=(batch_size,)).int().cuda()

    output = moe(inp, gate)
    output_einsum = moe_einsum(inp.clone(), gate.clone())

    print(output)
    print(output_einsum)

    # TODO: also exercise the backward pass, e.g. output.mean().backward().


# Script entry point: run the manual comparison when executed directly.
if __name__ == "__main__":
    test()