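"""Mixture-of-experts (MoE) layer backed by the ``moe_cuda`` extension.

``MOELayer`` runs the custom CUDA kernels through ``MOEFunction``;
``MOELayer_raw`` is a pure-PyTorch per-sample reference implementation;
``test()`` runs both on the same inputs.
"""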
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Seed the CPU and CUDA random number generators at import time, so the
# weight initialization and the random inputs drawn in test() repeat from run
# to run. torch.cuda.manual_seed is silently ignored when CUDA is unavailable.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

class MOEFunction(Function):
    """Autograd bridge to the ``moe_cuda`` forward/backward kernels.

    The kernels receive the expert weights in a per-expert column-major
    layout. ``forward`` converts ``weight`` into that layout, and ``backward``
    converts the weight gradient back to the row-major layout of ``weight``.
    What the kernel computes per sample lives in ``moe_cuda``. It presumably
    matches ``MOELayer_raw`` (``weight[gate[i]] @ inp[i]``), which is what
    ``test()`` compares it against.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run the ``moe_cuda`` forward kernel.

        Args:
            ctx: autograd context; stores the tensors ``backward`` needs.
            inp: input activations. Presumably (batch, in_feat); confirm
                against ``moe_cuda``.
            gate: per-sample expert index tensor. It receives no gradient.
            weight: expert weights of shape (num_expert, out_feat, in_feat).

        Returns:
            The first element of ``moe_cuda.forward``'s result.
        """
        out_feat, in_feat = weight.size()[1:]
        # Transpose every expert matrix and make it contiguous, then relabel
        # the buffer with the original (out_feat, in_feat) shape. The bytes
        # are now each matrix in column-major order.
        weight_column_major = (
            weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        )
        output = moe_cuda.forward(inp, gate, weight_column_major)
        ctx.save_for_backward(inp, gate, weight_column_major)
        # moe_cuda.forward presumably returns a sequence of tensors; only the
        # first one is the layer output.
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate ``grad_out`` through the ``moe_cuda`` kernel.

        Returns:
            Gradients for ``(inp, gate, weight)``. ``gate`` is an index tensor
            and gets ``None``.
        """
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
        # Treat the kernel's weight gradient as per-expert column-major data,
        # i.e. (in_feat, out_feat) row-major. Transpose it back to the
        # (out_feat, in_feat) layout of the ``weight`` parameter.
        grad_weight_row_major = (
            grad_weight.view(-1, in_feat, out_feat)
            .transpose(-1, -2)
            .contiguous()
            .view(-1, out_feat, in_feat)
        )
        return grad_inp, None, grad_weight_row_major


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer computed by the ``moe_cuda`` kernels.

    Holds one bias-free (out_feat, in_feat) weight matrix per expert. The
    forward pass is delegated to ``MOEFunction``.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialized on purpose: reset_parameters() overwrites
        # every expert slice right away.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize each expert exactly like the weight of a fresh ``nn.Linear``.

        A throw-away ``nn.Linear`` is built per expert so that the
        initialization and the RNG draws match ``nn.Linear``'s own.
        """
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply the experts selected by ``gate`` to ``inp`` via ``MOEFunction``."""
        return MOEFunction.apply(inp, gate, self.weight)


class MOELayer_raw(nn.Module):
    """Pure-PyTorch reference implementation of a mixture-of-experts layer.

    Computes each sample with an explicit Python loop. It is slow but needs
    no custom kernel, so ``test()`` uses it to check ``MOELayer``.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialized on purpose: reset_parameters() overwrites
        # every expert slice right away.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize each expert exactly like the weight of a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Return ``x`` with ``x[i] = weight[gate[i]] @ inp[i]``.

        Args:
            inp: (batch, in_feat) inputs. Row ``i`` is multiplied by an
                (out_feat, in_feat) expert matrix.
            gate: one expert index per row of ``inp``. It is cast to ``long``
                before indexing.

        Returns:
            A (batch, out_feat) tensor with ``inp``'s dtype and device.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Writing into the slices of x is recorded by autograd, so gradients
        # reach both inp and self.weight.
        for i in range(batch_size):
            x[i] = self.weight[gate_long[i]] @ inp[i]
        return x

def test():
    """Run MOELayer (CUDA kernel) and MOELayer_raw (reference) side by side.

    Both layers share identical expert weights and receive the same inputs
    through a shared ``nn.Linear``. Outputs and gradients are printed and then
    asserted to agree. Requires a CUDA device and the ``moe_cuda`` extension.
    """
    batch_size = 4
    num_expert = 4
    in_feat = 2
    out_feat = 3
    # Loose tolerances: the kernel and the reference loop may use different
    # matmul implementations, and therefore round differently.
    rtol, atol = 1e-3, 1e-4

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    with torch.no_grad():
        moe_raw.weight.copy_(moe.weight)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    # Pass 1: custom CUDA kernel.
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    print("moe output", output)
    output.mean().backward()
    print("moe.weight.grad", moe.weight.grad)
    print("linear.weight.grad", linear.weight.grad)
    print("linear.bias.grad", linear.bias.grad)
    # Snapshot before zero_grad(), which may zero these tensors in place.
    linear_weight_grad = linear.weight.grad.clone()
    linear_bias_grad = linear.bias.grad.clone()

    # Pass 2: pure-PyTorch reference on the same inputs.
    linear.zero_grad()
    moe_raw.zero_grad()
    x = linear(inp)
    output_raw = moe_raw(x, gate)
    print("moe_raw output", output_raw)
    output_raw.mean().backward()
    print("moe_raw.weight.grad", moe_raw.weight.grad)
    print("linear_raw.weight.grad", linear.weight.grad)
    print("linear_raw.bias.grad", linear.bias.grad)

    checks = (
        ("output", output, output_raw),
        ("expert weight grad", moe.weight.grad, moe_raw.weight.grad),
        ("linear weight grad", linear_weight_grad, linear.weight.grad),
        ("linear bias grad", linear_bias_grad, linear.bias.grad),
    )
    for name, got, expected in checks:
        assert torch.allclose(got, expected, rtol=rtol, atol=atol), (
            "%s mismatch between MOELayer and MOELayer_raw:\n%s\nvs\n%s"
            % (name, got, expected))
    print("MOELayer matches MOELayer_raw")


if __name__ == '__main__':
    test()