moe.py 4.44 KB
Newer Older
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
1
2
3
4
5
6
7
8
9
10
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda


class MOEFunction(Function):
    """Autograd bridge to the ``moe_cuda`` extension.

    ``forward`` and ``backward`` delegate the actual computation to
    ``moe_cuda.forward`` / ``moe_cuda.backward``; this class only converts
    the expert weights between the layout used by the Python side and the
    layout handed to the extension.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run the extension's forward kernel.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: input activations (presumably ``(batch, in_feat)`` --
                TODO confirm against the extension).
            gate: per-sample expert selector; ``MOELayer.forward`` passes
                it as an int32 tensor.
            weight: expert weights; the two trailing dimensions are read
                as ``(out_feat, in_feat)``.

        Returns:
            The first element of whatever ``moe_cuda.forward`` returns.
        """
        out_feat, in_feat = weight.size()[1:]
        # Transposing and re-packing stores every expert matrix in
        # column-major order; the final view only restores the nominal
        # (out_feat, in_feat) shape on top of that memory.
        weight_column_major = (
            weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        )
        output = moe_cuda.forward(inp, gate, weight_column_major)
        ctx.save_for_backward(inp, gate, weight_column_major)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Run the extension's backward kernel.

        Args:
            ctx: autograd context holding ``(inp, gate, weight_column_major)``
                saved by ``forward``.
            grad_out: gradient with respect to the forward output.

        Returns:
            ``(grad_inp, None, grad_weight)``; ``gate`` receives no gradient.
        """
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
        # Undo the layout change applied in forward(): reinterpret the
        # buffer as (in_feat, out_feat), transpose, and re-pack so the
        # gradient is laid out like the original ``weight`` argument.
        grad_weight_row_major = (
            grad_weight.view(-1, in_feat, out_feat)
            .transpose(-1, -2)
            .contiguous()
            .view(-1, out_feat, in_feat)
        )
        return grad_inp, None, grad_weight_row_major
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
30
31
32
33
34


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer evaluated through ``MOEFunction``.

    Holds one ``(out_feat, in_feat)`` weight matrix per expert in a single
    parameter of shape ``(num_expert, out_feat, in_feat)``. There is no bias.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every expert with ``nn.Linear``'s weight init."""
        # A throwaway nn.Linear is built per expert so each slice gets
        # exactly the initialisation (and RNG consumption) nn.Linear uses.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply the experts selected by ``gate`` to ``inp``.

        ``gate`` is cast to int32 before being handed to ``MOEFunction``.
        """
        return MOEFunction.apply(inp, gate.int(), self.weight)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
49
50


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
51
class MOELayer_raw(nn.Module):
    """Pure-PyTorch reference for ``MOELayer``.

    Stores the same ``(num_expert, out_feat, in_feat)`` weight parameter but
    evaluates it with one matrix-vector product per sample instead of the
    CUDA extension.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every expert with ``nn.Linear``'s weight init."""
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Multiply each row of ``inp`` by the expert chosen in ``gate``.

        Args:
            inp: tensor whose first dimension is the batch; row ``i`` is
                multiplied by ``self.weight[gate[i]]``.
            gate: integer tensor with one expert index per row of ``inp``.

        Returns:
            Tensor of shape ``(batch, out_feat)`` created via
            ``inp.new_zeros``.
        """
        gate_long = gate.long()  # indexing below uses int64 indices
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Deliberately a per-sample loop: gathering weight[gate] for the
        # whole batch would materialise batch * out_feat * in_feat values.
        for i in range(batch_size):
            x[i] = self.weight[gate_long[i]] @ inp[i]
        return x

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
74

Rick Ho's avatar
Rick Ho committed
75
def test():
    """Print outputs and gradients of ``MOELayer`` and ``MOELayer_raw``.

    Both layers are given identical weights and are fed the output of the
    same ``nn.Linear`` on the same random input, first through the CUDA
    path and then through the reference path. Nothing is asserted; the
    printed tensors are meant to be compared by the reader. Requires a
    CUDA device.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size = 4
    num_expert = 4
    in_feat = 2
    out_feat = 3

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    # Give the reference layer exactly the same expert weights.
    with torch.no_grad():
        moe_raw.weight.copy_(moe.weight)

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(
        low=0, high=num_expert, size=(batch_size, ), requires_grad=False
    ).int().cuda()

    # Pass 1: CUDA extension path.
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    print("moe output", output)
    y = output.mean()
    y.backward()
    print("moe.weight.grad", moe.weight.grad)
    print("linear.weight.grad", linear.weight.grad)
    print("linear.bias.grad", linear.bias.grad)

    # Pass 2: reference path. The shared ``linear`` is reset so its
    # gradients below come from this pass only, and the layer under test
    # in this pass (moe_raw) is the one whose gradients are cleared.
    linear.zero_grad()
    moe_raw.zero_grad()
    x = linear(inp.clone())
    output_raw = moe_raw(x, gate.clone())
    print("moe_raw output", output_raw)
    y_raw = output_raw.mean()
    y_raw.backward()
    print("moe_raw.weight.grad", moe_raw.weight.grad)
    print("linear_raw.weight.grad", linear.weight.grad)
    print("linear_raw.bias.grad", linear.bias.grad)
Rick Ho's avatar
Rick Ho committed
114

Jiezhong Qiu's avatar
Jiezhong Qiu committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def test_dp():
    """Exercise ``torch.nn.DataParallel`` on devices 0 and 1.

    A plain ``nn.Linear`` is wrapped and run first, followed by a
    ``MOELayer``. Progress is reported with ``print``; note that no success
    message is printed after the MoE forward pass. Requires two CUDA
    devices.
    """
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    n_rows, n_experts, d_in, d_out = 4, 4, 2, 3
    devices = (0, 1)

    # Inputs are drawn before any module is built so the random stream is
    # consumed in a fixed order.
    inp = torch.rand(n_rows, d_in).cuda()
    gate = torch.randint(
        low=0, high=n_experts, size=(n_rows, ), requires_grad=False
    ).int().cuda()

    print("data parallel of a nn.Linear model")
    dense = nn.Linear(d_in, d_in).cuda()
    dense_dp = torch.nn.DataParallel(dense, device_ids=list(devices))
    _ = dense_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    experts = MOELayer(n_experts, d_in, d_out).cuda()
    experts_dp = torch.nn.DataParallel(experts, device_ids=list(devices))
    _ = experts_dp(inp, gate)



Rick Ho's avatar
Rick Ho committed
139
if __name__ == '__main__':
    # Only the DataParallel check runs by default; call test() here instead
    # to print the CUDA-vs-reference output and gradient comparison.
    test_dp()