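# moe.py — mixture-of-experts layer backed by the custom `moe_cuda` extension,
# a pure-PyTorch per-row reference implementation (`MOELayer_einsum`), and a
# small smoke test that compares the two.
# (Recovered from a GitLab blame view; commits by Jiezhong Qiu.)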
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Seed both the CPU and the CUDA generators so parameter initialisation and
# the random `inp` / `gate` tensors of the smoke test below repeat run to run.
_SEED = 42
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)

class MOEFunction(Function):
    """Autograd bridge to the ``moe_cuda`` mixture-of-experts kernels.

    ``forward`` hands the kernel a column-major copy of the stacked expert
    weights and returns the first element of ``moe_cuda.forward``'s result.
    ``backward`` converts the kernel's weight gradient back to the
    ``[expert, out, in]`` row-major layout of the parameter. ``gate`` receives
    no gradient.
    """

    @staticmethod
    def _to_column_major(weight):
        """Return a copy of ``weight`` with each expert matrix stored column by column.

        ``weight`` is indexed ``[expert, out, in]``. The result keeps the shape
        label ``(-1, out_feat, in_feat)``, but its flat storage holds each
        expert's ``(out_feat, in_feat)`` matrix in column-major order.
        NOTE(review): that ``moe_cuda`` expects this layout is inferred from
        the variable naming in this file; confirm against the extension.
        """
        out_feat, in_feat = weight.size()[1:]
        return weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)

    @staticmethod
    def _from_column_major(weight_cm):
        """Invert ``_to_column_major``: column-major storage -> ``[expert, out, in]``.

        The input is labelled ``(-1, out_feat, in_feat)`` but stores every
        matrix column by column, so its storage is first read as
        ``(-1, in_feat, out_feat)`` (the transposed matrices) and then
        transposed back. Re-applying ``_to_column_major`` instead would only
        round-trip when ``out_feat == in_feat``.
        """
        out_feat, in_feat = weight_cm.size()[1:]
        return weight_cm.reshape(-1, in_feat, out_feat).transpose(-1, -2).contiguous()

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run the CUDA MoE forward kernel.

        Args:
            inp: input rows (presumably ``(batch, in_feat)`` on CUDA, as in the
                smoke test at the bottom of this file; confirm).
            gate: per-row expert indices (the smoke test passes int32 CUDA
                values in ``[0, num_expert)``).
            weight: stacked expert weights of shape
                ``(num_expert, out_feat, in_feat)``.

        Returns:
            ``moe_cuda.forward(...)[0]``.
        """
        weight_column_major = MOEFunction._to_column_major(weight)
        output = moe_cuda.forward(inp, gate, weight_column_major)
        # Save the re-laid-out weight (not ``weight``): the backward kernel is
        # given the same tensors the forward kernel saw.
        ctx.save_for_backward(inp, gate, weight_column_major)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Run the CUDA MoE backward kernel and return ``(grad_inp, None, grad_weight)``."""
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        # Assumes the kernel returns grad_weight in the same column-major
        # layout it received the weight in (as the original variable naming
        # states). Undo that layout with the true inverse of the forward
        # conversion so the gradient lines up with ``weight[expert, out, in]``.
        grad_weight_row_major = MOEFunction._from_column_major(grad_weight)
        return grad_inp, None, grad_weight_row_major


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer whose forward pass runs through ``MOEFunction``.

    Holds one ``(out_feat, in_feat)`` weight matrix per expert, stacked into a
    single ``(num_expert, out_feat, in_feat)`` parameter. There is no bias.

    Args:
        num_expert: number of experts (weight matrices).
        in_feat: input feature size.
        out_feat: output feature size.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert's matrix the way a fresh ``nn.Linear`` would.

        A throwaway ``nn.Linear`` is built per expert and its weight copied in.
        Its bias is left enabled so the sequence of random draws under a fixed
        seed is the same as before.
        """
        for i in range(self.num_expert):
            # Use this layer's own out_feat, not a module-level name, so the
            # layer also works when imported from another module.
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            with torch.no_grad():
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply expert ``gate[b]`` to row ``inp[b]`` via the ``moe_cuda`` kernels.

        Args:
            inp: input rows; presumably ``(batch, in_feat)`` on CUDA, as in
                the smoke test at the bottom of this file (confirm).
            gate: per-row expert indices; the smoke test passes an int32 CUDA
                tensor with values in ``[0, num_expert)``.
        """
        return MOEFunction.apply(inp, gate, self.weight)


class MOELayer_einsum(nn.Module):
    """Pure-PyTorch reference for ``MOELayer`` with the same parameter layout.

    Despite the name, no ``torch.einsum`` is used: ``forward`` loops over the
    batch and multiplies each row by its selected expert matrix. It does not
    touch ``moe_cuda``, so it can check the CUDA path.

    Args:
        num_expert: number of experts (weight matrices).
        in_feat: input feature size.
        out_feat: output feature size.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert's matrix the way a fresh ``nn.Linear`` would.

        A throwaway ``nn.Linear`` is built per expert and its weight copied in.
        Its bias is left enabled so the sequence of random draws under a fixed
        seed is the same as before.
        """
        for i in range(self.num_expert):
            # Use this layer's own out_feat, not a module-level name, so the
            # layer also works when imported from another module.
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            with torch.no_grad():
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Return ``out[b] = weight[gate[b]] @ inp[b]`` for every row ``b``.

        Args:
            inp: 2-D input whose rows have ``in_feat`` entries.
            gate: per-row expert indices; cast to ``long`` before indexing.

        Returns:
            A ``(batch, out_feat)`` tensor with ``inp``'s dtype and device.
        """
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # One expert matmul per row. Gathering self.weight[gate_long] for the
        # whole batch at once would allocate a (batch, out_feat, in_feat) tensor.
        for i in range(batch_size):
            x[i] = self.weight[gate_long[i]] @ inp[i]
        return x

if __name__ == "__main__":
    # Smoke test: compare the CUDA-backed layer against the pure-PyTorch
    # reference on a tiny, non-square problem. Runs only as a script, so
    # importing this module does no CUDA work. These assignments remain
    # module-level globals (an `if` does not open a new scope).
    batch_size = 4
    num_expert = 4
    in_feat = 2
    out_feat = 3

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
    # Give the reference layer identical weights so the outputs are comparable.
    moe_einsum.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

    output = moe(inp, gate)
    # Clone the inputs so the reference sees pristine copies.
    output_einsum = moe_einsum(inp.clone(), gate.clone())

    print(output)
    print(output_einsum)
    print("max abs diff:", (output - output_einsum).abs().max().item())

    # TODO: also exercise the backward pass, e.g.
    #     y = output.mean()
    #     y.backward()