moe.py 4.49 KB
Newer Older
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
1
2
3
4
5
6
7
8
9
10
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda


class MOEFunction(Function):
    """Autograd bridge between PyTorch and the ``moe_cuda`` extension.

    Every numeric step (expert counting, scatter, per-expert matmul, gather
    and their backward counterparts) happens inside ``moe_cuda``; this class
    only sequences those calls and tells autograd what to keep for
    ``backward``.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Route rows of ``inp`` to the expert selected by ``gate``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: input rows; the tests in this file pass ``(batch, in_feat)``
                tensors (other shapes unverified — TODO confirm).
            gate: integer expert index per row (``MOELayer`` casts it with
                ``.int()`` before calling).
            weight: stacked expert matrices; ``MOELayer`` allocates it as
                ``(num_expert, out_feat, in_feat)``.

        Returns:
            The first element of the ``moe_cuda.local_gather`` result.
        """
        counts, positions = moe_cuda.expert_count(weight, gate)
        (scattered_inp,) = moe_cuda.local_scatter(inp, positions)
        (expert_out,) = moe_cuda.forward(scattered_inp, weight, counts)
        gathered = moe_cuda.local_gather(expert_out, positions)

        # ``gate`` is saved alongside the buffers even though backward does
        # not read it; the saved layout is kept as-is.
        ctx.save_for_backward(scattered_inp, gate, weight, counts, positions)
        return gathered[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate ``grad_out`` back to ``inp`` and ``weight``.

        Returns:
            ``(grad_inp, None, grad_weight)``: ``gate`` is an integer index
            tensor and receives no gradient.
        """
        scattered_inp, _gate, weight, counts, positions = ctx.saved_tensors
        # Apply the same ``local_scatter`` permutation (``positions``) that
        # forward applied to the inputs, so gradients line up with them.
        (grad_out_buf,) = moe_cuda.local_scatter(grad_out.contiguous(), positions)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, scattered_inp, weight, counts)
        (grad_inp,) = moe_cuda.local_gather(grad_inp_buf, positions)
        return grad_inp, None, grad_weight
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
34
35
36


class MOELayer(nn.Module):
    """Bias-free mixture-of-experts linear layer backed by ``MOEFunction``.

    One ``out_feat x in_feat`` matrix per expert is stacked into a single
    parameter of shape ``(num_expert, out_feat, in_feat)``.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer, self).__init__()
        self.num_expert, self.in_feat, self.out_feat = num_expert, in_feat, out_feat
        # Storage is filled in by reset_parameters() right below.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert exactly like a fresh ``nn.Linear`` weight.

        A throw-away ``nn.Linear`` is built per expert (its bias is discarded)
        so the initialisation and RNG consumption match ``nn.Linear``.
        """
        with torch.no_grad():
            for expert_idx in range(self.num_expert):
                template = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
                self.weight[expert_idx].copy_(template.weight)

    def forward(self, inp, gate):
        """Apply expert ``gate[i]`` to row ``inp[i]`` via the CUDA kernel.

        ``gate`` is cast to int32 because that is what ``MOEFunction`` is
        handed by this layer.
        """
        return MOEFunction.apply(inp, gate.int(), self.weight)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
53
54


Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
55
class MOELayer_raw(nn.Module):
    """Pure-PyTorch reference implementation of ``MOELayer`` (no CUDA kernel).

    Holds the same ``(num_expert, out_feat, in_feat)`` stacked weight and
    computes ``out[i] = inp[i] @ weight[gate[i]].T`` with ordinary autograd,
    so it can be used to check the custom kernel numerically.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Storage is filled in by reset_parameters() right below.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert exactly like a fresh ``nn.Linear`` weight."""
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply expert ``gate[i]`` to row ``inp[i]``.

        Args:
            inp: ``(batch, in_feat)`` input rows.
            gate: ``(batch,)`` integer expert indices in ``[0, num_expert)``.

        Returns:
            ``(batch, out_feat)`` tensor with the same dtype/device as ``inp``.

        Raises:
            IndexError: if any gate value is outside ``[0, num_expert)``.
        """
        gate_long = gate.long()
        if gate_long.numel() > 0:
            lowest, highest = int(gate_long.min()), int(gate_long.max())
            # Negative indices would otherwise silently wrap to the last experts.
            if lowest < 0 or highest >= self.num_expert:
                raise IndexError(
                    'gate values must be in [0, {}), got range [{}, {}]'.format(
                        self.num_expert, lowest, highest))

        out = inp.new_zeros((inp.size(0), self.out_feat))
        # Loop over experts (usually few) rather than samples (usually many);
        # each expert does one batched matmul over the rows routed to it.
        for expert in range(self.num_expert):
            mask = gate_long == expert
            out[mask] = inp[mask] @ self.weight[expert].t()
        return out

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
79

Rick Ho's avatar
Rick Ho committed
80
81
82
83
84
85
86
87
88
89
def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


Rick Ho's avatar
Rick Ho committed
90
def test():
    """Compare the CUDA ``MOELayer`` against the pure-PyTorch reference.

    Builds both layers with identical expert weights, pushes the same input
    through :func:`test_module` for each, and prints the summed absolute
    difference of the output and of every gradient. Requires CUDA.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size, num_expert, in_feat, out_feat = 4, 2, 6, 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    # Give the reference layer exactly the same expert weights.
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    # Clones keep the two runs from sharing (and possibly mutating) inputs.
    kernel_results = test_module(moe, linear, inp.clone(), gate.clone())
    reference_results = test_module(moe_raw, linear, inp.clone(), gate.clone())

    labels = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    for label, got, want in zip(labels, kernel_results, reference_results):
        total_err = (got - want).abs().sum()
        print('{} abs err {}'.format(label, total_err))
Rick Ho's avatar
Rick Ho committed
114

Jiezhong Qiu's avatar
Jiezhong Qiu committed
115
116
117
def test_dp(device_ids=None):
    """Smoke-test ``nn.DataParallel`` on a plain Linear and on ``MOELayer``.

    Args:
        device_ids: CUDA device indices handed to ``nn.DataParallel``.
            ``None`` (default) uses every visible GPU instead of assuming
            that devices 0, 1 and 2 exist, which crashed on smaller machines.

    Only checks that the forward passes run; outputs are not compared.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    else:
        device_ids = list(device_ids)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=device_ids)
    linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=device_ids)
    # Repeated calls exercise replica re-creation on every forward.
    for _ in range(5):
        moe_dp(inp, gate)
Jiezhong Qiu's avatar
Jiezhong Qiu committed
137
138
139



Rick Ho's avatar
Rick Ho committed
140
if __name__ == '__main__':
    # Numerical comparison of the CUDA layer against the reference layer.
    test()
    # Multi-GPU DataParallel smoke test; needs multiple CUDA devices.
    # test_dp()