# moe.py -- mixture-of-experts layer backed by the ``moe_cuda`` extension,
# a pure-PyTorch reference implementation, and a CUDA comparison test.
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Fix the RNG seeds at import time so that parameter initialisation and the
# random test inputs below are reproducible from run to run.
torch.manual_seed(42)
torch.cuda.manual_seed(42)


class MOEFunction(Function):
    """Autograd function that delegates to the ``moe_cuda`` extension.

    ``forward`` calls ``moe_cuda.forward`` and ``backward`` calls
    ``moe_cuda.backward``; this class only wires them into autograd.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight1, weight2):
        """Run the extension's forward pass.

        Args:
            ctx: autograd context; the four inputs are saved on it for
                ``backward``.
            inp: input tensor handed to ``moe_cuda.forward``.
            gate: gate tensor handed to ``moe_cuda.forward``.
            weight1: first weight tensor handed to ``moe_cuda.forward``.
            weight2: second weight tensor handed to ``moe_cuda.forward``.

        Returns:
            The first element of the sequence returned by
            ``moe_cuda.forward``.
        """
        output = moe_cuda.forward(inp, gate, weight1, weight2)
        ctx.save_for_backward(inp, gate, weight1, weight2)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Run the extension's backward pass.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            grad_out: gradient with respect to the forward output; made
                contiguous before it is passed to the extension.

        Returns:
            ``(grad_inp, None, grad_weight)``; the gate receives no gradient.
        """
        # NOTE(review): ``forward`` takes four inputs (inp, gate, weight1,
        # weight2) but only three gradients are returned here and only a
        # single weight gradient is unpacked from ``moe_cuda.backward``.
        # Autograd expects one gradient per forward input, so this looks
        # like it predates the two-weight forward -- confirm against the
        # extension's actual return value before relying on backward.
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
        # Reinterpret the buffer as (expert, in_feat, out_feat), then
        # transpose back to (expert, out_feat, in_feat). Going by the
        # variable name this converts a column-major result to row-major --
        # TODO confirm the layout the extension produces.
        grad_weight_row_major = (
            grad_weight.view(-1, in_feat, out_feat)
            .transpose(-1, -2)
            .contiguous()
            .view(-1, out_feat, in_feat)
        )
        return grad_inp, None, grad_weight_row_major


class MOELayer(nn.Module):
    """Mixture-of-experts layer with two weight tensors per expert.

    The computation itself is performed by ``MOEFunction``; this module
    owns the parameters.

    Attributes:
        weight1: parameter of shape ``(num_expert, hidden_feat, in_feat)``.
        weight2: parameter of shape ``(num_expert, out_feat, hidden_feat)``.
    """

    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
        """Allocate and initialise the per-expert weights.

        Args:
            num_expert: number of experts (leading dimension of both weights).
            in_feat: input feature size.
            hidden_feat: hidden feature size between the two weights.
            out_feat: output feature size.
        """
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.hidden_feat = hidden_feat
        self.out_feat = out_feat
        # torch.empty makes the "uninitialised until reset_parameters()"
        # intent explicit (the legacy torch.Tensor(sizes) constructor did
        # the same thing implicitly).
        self.weight1 = nn.Parameter(
            torch.empty(num_expert, hidden_feat, in_feat))
        self.weight2 = nn.Parameter(
            torch.empty(num_expert, out_feat, hidden_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert slice with ``nn.Linear``'s default weight init.

        A throw-away ``nn.Linear`` is built per expert and per weight, in the
        order weight1 then weight2, and only its ``weight`` is copied; the
        bias it creates is discarded.
        """
        # Write in place under no_grad instead of going through ``.data``.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
                self.weight1[i].copy_(linear.weight)
                linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
                self.weight2[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply ``MOEFunction`` to ``inp`` and ``gate`` with this layer's weights."""
        return MOEFunction.apply(inp, gate, self.weight1, self.weight2)


class MOELayer_raw(nn.Module):
    """Pure-PyTorch mixture-of-experts layer that loops over the batch.

    Each sample is multiplied by the two weight matrices of the expert
    selected by its gate entry. It uses no custom extension.

    Attributes:
        weight1: parameter of shape ``(num_expert, hidden_feat, in_feat)``.
        weight2: parameter of shape ``(num_expert, out_feat, hidden_feat)``.
    """

    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
        """Allocate and initialise the per-expert weights.

        Args:
            num_expert: number of experts (leading dimension of both weights).
            in_feat: input feature size.
            hidden_feat: hidden feature size between the two weights.
            out_feat: output feature size.
        """
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.hidden_feat = hidden_feat
        self.out_feat = out_feat
        self.weight1 = nn.Parameter(
            torch.empty(num_expert, hidden_feat, in_feat))
        self.weight2 = nn.Parameter(
            torch.empty(num_expert, out_feat, hidden_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert slice with ``nn.Linear``'s default weight init.

        A throw-away ``nn.Linear`` is built per expert and per weight, in the
        order weight1 then weight2, and only its ``weight`` is copied.
        """
        # Write in place under no_grad instead of going through ``.data``.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
                self.weight1[i].copy_(linear.weight)
                linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
                self.weight2[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Compute ``weight2[g] @ (weight1[g] @ inp[i])`` for every row ``i``.

        Args:
            inp: tensor whose first dimension is the batch; each row is
                multiplied by ``weight1[g].t()``, so rows have ``in_feat``
                elements.
            gate: per-row expert indices; converted with ``.long()`` before
                being used to index the weights.

        Returns:
            Tensor of shape ``(batch, out_feat)`` created with
            ``inp.new_zeros``, so it shares ``inp``'s dtype and device.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            expert = gate_long[i]
            hid = inp[i] @ self.weight1[expert].t()
            x[i] = hid @ self.weight2[expert].t()
        return x


def test_module(moe, linear, inp, gate):
    """Run ``inp`` through ``linear`` and then ``moe`` and return the output.

    Gradients of both modules are cleared first. Only the forward pass is
    run: the earlier backward/gradient comparison was unreachable behind an
    early ``return`` (and referred to a ``moe.weight`` attribute that the
    layers in this file do not define), so it has been removed.

    Args:
        moe: module called as ``moe(x, gate)``.
        linear: module applied to ``inp`` before ``moe``.
        inp: input tensor for ``linear``.
        gate: gate tensor passed through to ``moe``.

    Returns:
        The output of ``moe``.
    """
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    return moe(x, gate)


# Helper, not a pytest test case despite its name: opt out of collection so
# that pytest does not try to resolve its parameters as fixtures.
test_module.__test__ = False


def test():
    """Compare ``MOELayer`` against ``MOELayer_raw`` on a small random batch.

    Both layers are given identical weights and the same inputs, then the
    summed absolute difference of their outputs is printed. Requires CUDA:
    all modules and tensors are moved with ``.cuda()``.
    """
    batch_size = 4
    num_expert = 2
    in_feat = 6
    hidden_feat = 12
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
    # Give the reference layer exactly the same weights as the layer under test.
    with torch.no_grad():
        moe_raw.weight1.copy_(moe.weight1)
        moe_raw.weight2.copy_(moe.weight2)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    # test_module returns a single tensor. Zipping the two outputs against a
    # one-element name list iterated over tensor rows and compared only the
    # first sample, so compare the complete outputs directly.
    err = (moe_out - raw_out).abs().sum()
    print('{} abs err {}'.format('Out', err))


# Run the comparison test only when this file is executed as a script.
if __name__ == '__main__':
    test()