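"""moe.py: mixture-of-experts layer backed by the ``moe_cuda`` extension.

Also contains a plain PyTorch reference layer (``MOELayer_raw``) and a small
script that compares the outputs of the two implementations.
"""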
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Fix the CPU and CUDA random generators so weights and inputs are
# reproducible from run to run.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

class MOEFunction(Function):
    """Autograd bridge to the ``moe_cuda`` extension.

    ``forward`` passes the input, the gate and both weight tensors to
    ``moe_cuda.forward``; ``backward`` passes the incoming gradient plus the
    same four tensors to ``moe_cuda.backward``.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight1, weight2):
        """Run ``moe_cuda.forward`` and save its inputs for ``backward``.

        Args:
            ctx: autograd context supplied by ``Function.apply``.
            inp: input tensor handed to the extension.
            gate: gate tensor handed to the extension; ``backward`` returns
                ``None`` as its gradient.
            weight1: first weight tensor handed to the extension.
            weight2: second weight tensor handed to the extension.

        Returns:
            Element 0 of the sequence returned by ``moe_cuda.forward``.
        """
        output = moe_cuda.forward(inp, gate, weight1, weight2)
        ctx.save_for_backward(inp, gate, weight1, weight2)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Return one gradient per ``forward`` input: ``(inp, gate, weights...)``.

        Autograd requires exactly as many return values as ``forward`` has
        inputs, so the number of weight gradients produced by the extension
        is checked instead of silently dropping the one for ``weight2``.

        Raises:
            RuntimeError: if ``moe_cuda.backward`` does not return one
                gradient per saved weight tensor.
        """
        grad_inp, *grad_weights = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)

        # saved_tensors is (inp, gate, weight1, weight2): every entry after
        # the first two is a weight that needs its own gradient.
        num_weights = len(ctx.saved_tensors) - 2
        if len(grad_weights) != num_weights:
            raise RuntimeError(
                'moe_cuda.backward returned {} weight gradient(s), '
                'expected {}'.format(len(grad_weights), num_weights))

        # NOTE(review): the memory layout of the gradients returned by the
        # extension is not visible from this file. The layout conversion the
        # original code applied to its single weight gradient is applied to
        # each of them here -- confirm against moe_cuda.backward.
        return (grad_inp, None) + tuple(
            MOEFunction._to_row_major(grad) for grad in grad_weights)

    @staticmethod
    def _to_row_major(grad_weight):
        """Reinterpret each trailing 2-D slice as its transpose, made contiguous.

        The data of a tensor of shape ``(..., out_feat, in_feat)`` is read as
        ``(-1, in_feat, out_feat)``, transposed, and returned with shape
        ``(-1, out_feat, in_feat)``.
        """
        out_feat, in_feat = grad_weight.size()[1:]
        return (grad_weight.view(-1, in_feat, out_feat)
                .transpose(-1, -2)
                .contiguous()
                .view(-1, out_feat, in_feat))


class MOELayer(nn.Module):
    """Expert layer with two stacked weight tensors, run via ``MOEFunction``.

    ``weight1`` has shape ``(num_expert, hidden_feat, in_feat)`` and
    ``weight2`` has shape ``(num_expert, out_feat, hidden_feat)``.
    """

    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
        """Allocate both weight tensors and initialise them.

        Args:
            num_expert: number of experts (size of dimension 0 of each weight).
            in_feat: input feature size.
            hidden_feat: feature size between the two weight matrices.
            out_feat: output feature size.
        """
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.hidden_feat = hidden_feat
        self.out_feat = out_feat
        # torch.empty replaces the legacy torch.Tensor(sizes) constructor;
        # the contents are filled in by reset_parameters() right below.
        self.weight1 = nn.Parameter(
            torch.empty(num_expert, hidden_feat, in_feat))
        self.weight2 = nn.Parameter(
            torch.empty(num_expert, out_feat, hidden_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill every expert's matrices with the weights of fresh ``nn.Linear`` modules.

        The ``nn.Linear`` modules are only used as a source of initial
        weights; their biases are discarded.
        """
        with torch.no_grad():
            for i in range(self.num_expert):
                # Keep the construction order (in->hidden, then hidden->out)
                # so the random stream is consumed as before.
                first = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
                self.weight1[i].copy_(first.weight)
                second = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
                self.weight2[i].copy_(second.weight)

    def forward(self, inp, gate):
        """Apply ``MOEFunction`` to ``inp`` and ``gate`` with this layer's weights."""
        return MOEFunction.apply(inp, gate, self.weight1, self.weight2)


class MOELayer_raw(nn.Module):
    """Plain PyTorch expert layer that processes the batch one sample at a time.

    Holds the same two weight tensors as ``MOELayer``: ``weight1`` of shape
    ``(num_expert, hidden_feat, in_feat)`` and ``weight2`` of shape
    ``(num_expert, out_feat, hidden_feat)``.
    """

    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
        """Allocate both weight tensors and initialise them.

        Args:
            num_expert: number of experts (size of dimension 0 of each weight).
            in_feat: input feature size.
            hidden_feat: feature size between the two weight matrices.
            out_feat: output feature size.
        """
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.hidden_feat = hidden_feat
        self.out_feat = out_feat
        # torch.empty replaces the legacy torch.Tensor(sizes) constructor;
        # the contents are filled in by reset_parameters() right below.
        self.weight1 = nn.Parameter(
            torch.empty(num_expert, hidden_feat, in_feat))
        self.weight2 = nn.Parameter(
            torch.empty(num_expert, out_feat, hidden_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill every expert's matrices with the weights of fresh ``nn.Linear`` modules.

        The ``nn.Linear`` modules are only used as a source of initial
        weights; their biases are discarded.
        """
        with torch.no_grad():
            for i in range(self.num_expert):
                # Keep the construction order (in->hidden, then hidden->out)
                # so the random stream is consumed as before.
                first = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
                self.weight1[i].copy_(first.weight)
                second = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
                self.weight2[i].copy_(second.weight)

    def forward(self, inp, gate):
        """Multiply each sample by the two matrices of the expert its gate selects.

        Args:
            inp: tensor whose dimension 0 is the batch; row ``i`` is
                multiplied by ``weight1[gate[i]].t()`` and then by
                ``weight2[gate[i]].t()``.
            gate: integer-valued tensor indexed by sample; it is cast to
                ``long`` before being used as an index.

        Returns:
            Tensor of shape ``(batch_size, out_feat)`` created with
            ``inp.new_zeros``.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        out = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            expert = gate_long[i]
            hid = inp[i] @ self.weight1[expert].t()
            out[i] = hid @ self.weight2[expert].t()
        return out

def test_module(moe, linear, inp, gate):
    """Run ``inp`` through ``linear`` and then ``moe``; return the output.

    Gradients of both modules are cleared first. No backward pass is run:
    the statements that used to follow the ``return`` were unreachable and
    referenced a ``moe.weight`` attribute that the layers in this file do not
    define, so they have been removed.

    Args:
        moe: module called as ``moe(x, gate)``.
        linear: module applied to ``inp`` before ``moe``.
        inp: input tensor.
        gate: gate tensor forwarded to ``moe`` unchanged.

    Returns:
        The tensor produced by ``moe``.
    """
    linear.zero_grad()
    moe.zero_grad()
    output = moe(linear(inp), gate)

    # Print from rank 1 only, and only when a process group actually exists,
    # so the helper also works outside a distributed launch.
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized() and dist.get_rank() == 1:
        print(output)
    return output


# This is a helper that needs explicit arguments, not a pytest case; stop
# pytest from collecting it by name if this module is imported into a test file.
test_module.__test__ = False


def test():
    """Compare ``MOELayer`` with ``MOELayer_raw`` on one small CUDA batch.

    Both layers get identical weights and identical inputs; the summed
    absolute difference of their outputs is printed.
    """
    batch_size = 4
    num_expert = 2
    in_feat = 6
    hidden_feat = 12
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()

    moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
    # Give the reference layer exactly the weights of the layer under test.
    moe_raw.weight1.data = moe.weight1.data.clone()
    moe_raw.weight2.data = moe.weight2.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    # Deterministic round-robin assignment. For the sizes above this is the
    # [0, 1, 0, 1] gate the original hard-coded (overwriting a random draw),
    # and it stays valid if batch_size or num_expert are changed.
    gate = (torch.arange(batch_size) % num_expert).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    # Compare the full output tensors. Zipping the tensors with a one-element
    # name list iterated over rows and therefore only checked row 0.
    err = (moe_out - raw_out).abs().sum()
    print('{} abs err {}'.format('Out', err))


if __name__ == '__main__':
    # Create the default process group (MPI backend) before running the
    # comparison; test_module() consults torch.distributed.get_rank().
    torch.distributed.init_process_group(backend='mpi')
    test()