"vscode:/vscode.git/clone" did not exist on "5a290a5644d2eaa721d89e27a2b585a8ca95ce1c"
moe.py 3.27 KB
Newer Older
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
1
2
3
4
5
6
7
8
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Seed PyTorch's random number generators at import time so that the randomly
# initialised expert weights and the random inputs drawn in test() are
# reproducible from run to run.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
10
11
12

class MOEFunction(Function):
    """Autograd wrapper around the ``moe_cuda`` forward/backward kernels.

    The weight handed to the extension is re-laid-out so that each expert's
    matrix is stored transposed in memory while keeping the
    ``(num_expert, out_feat, in_feat)`` shape (named "column major" here).
    ``backward`` undoes that layout before returning the weight gradient.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run the CUDA forward kernel.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: input tensor forwarded unchanged to ``moe_cuda.forward``.
            gate: per-sample expert selection forwarded to the kernel; it
                receives no gradient.
            weight: expert weights; the last two dimensions are read as
                ``(out_feat, in_feat)``.

        Returns:
            The first element of whatever ``moe_cuda.forward`` returns.
        """
        out_feat, in_feat = weight.size()[1:]
        # Transpose + contiguous physically reorders the memory; the view then
        # restores the original shape on top of the transposed storage.
        weight_column_major = (
            weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        )
        output = moe_cuda.forward(inp, gate, weight_column_major)
        ctx.save_for_backward(inp, gate, weight_column_major)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Run the CUDA backward kernel.

        Args:
            ctx: autograd context holding ``(inp, gate, weight_column_major)``.
            grad_out: gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_inp, None, grad_weight)`` matching the ``forward``
            arguments ``(inp, gate, weight)``; ``gate`` gets no gradient.
        """
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
        # The kernel's gradient uses the same transposed storage as the weight
        # passed to forward; reinterpret it as (in_feat, out_feat) and
        # transpose back to obtain an ordinary row-major
        # (num_expert, out_feat, in_feat) tensor.
        grad_weight_row_major = (
            grad_weight.view(-1, in_feat, out_feat)
            .transpose(-1, -2)
            .contiguous()
        )
        return grad_inp, None, grad_weight_row_major
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
32
33
34
35
36


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer evaluated through ``MOEFunction``.

    Holds one ``(out_feat, in_feat)`` weight matrix per expert in a single
    parameter of shape ``(num_expert, out_feat, in_feat)``. There is no bias.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert with the weight of a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                # A throwaway nn.Linear is built only to reuse its default
                # weight initialisation; its bias is discarded.
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply the experts selected by ``gate`` to ``inp``.

        Both arguments are passed straight to ``MOEFunction.apply`` together
        with the layer's weight.
        """
        return MOEFunction.apply(inp, gate, self.weight)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
51
52


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
53
class MOELayer_raw(nn.Module):
    """Pure-PyTorch mixture-of-experts layer that loops over the batch.

    Holds one ``(out_feat, in_feat)`` weight matrix per expert in a single
    parameter of shape ``(num_expert, out_feat, in_feat)``. There is no bias.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert with the weight of a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                # A throwaway nn.Linear is built only to reuse its default
                # weight initialisation; its bias is discarded.
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Multiply every sample by the weight of its selected expert.

        Args:
            inp: tensor whose first dimension is the batch; ``inp[i]`` is
                multiplied by an ``(out_feat, in_feat)`` matrix.
            gate: integer tensor; ``gate[i]`` is the expert index used for
                sample ``i``.

        Returns:
            Tensor of shape ``(batch_size, out_feat)`` created with
            ``inp.new_zeros``, so it shares ``inp``'s dtype and device.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            x[i] = self.weight[gate_long[i]] @ inp[i]
        return x

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
76

Rick Ho's avatar
Rick Ho committed
77
78
79
80
81
def test():
    """Compare ``MOELayer`` with ``MOELayer_raw`` on a tiny random problem.

    Both layers are given identical weights, inputs and gates. Their forward
    outputs and weight gradients are printed for manual comparison. All
    tensors are moved to CUDA, so a CUDA device is required.
    """
    batch_size = 4
    num_expert = 4
    in_feat = 2
    out_feat = 3

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    # Give the reference layer exactly the same weights as the CUDA layer.
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, )).int().cuda()

    output = moe(inp, gate)
    output_raw = moe_raw(inp.clone(), gate.clone())

    print(output)
    print(output_raw)

    y = output.mean()
    y.backward()

    y_raw = output_raw.mean()
    y_raw.backward()

    print(moe.weight.grad)
    print(moe_raw.weight.grad)
Rick Ho's avatar
Rick Ho committed
104
105
106
107


# Run the CUDA-vs-reference comparison only when executed as a script.
if __name__ == '__main__':
    test()