# moe.py
# Mixture-of-experts layer backed by the moe_cuda extension, a pure-PyTorch
# reference layer, and a script that runs both on the same input.
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Fix the CPU and CUDA RNG seeds so repeated runs draw the same random values.
torch.manual_seed(42)
torch.cuda.manual_seed(42)


# Autograd binding for the moe_cuda extension.
class MOEFunction(Function):
    """Autograd function that delegates forward and backward to ``moe_cuda``."""

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Call ``moe_cuda.forward`` and keep the inputs for the backward pass.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: input tensor handed to the extension.
            gate: gate tensor handed to the extension (receives no gradient).
            weight: expert weight tensor handed to the extension.

        Returns:
            The first element of the sequence returned by ``moe_cuda.forward``.
        """
        output = moe_cuda.forward(inp, gate, weight)
        # backward() receives these back, in this order, via ctx.saved_tensors.
        ctx.save_for_backward(inp, gate, weight)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Call ``moe_cuda.backward`` with the upstream gradient and saved inputs.

        Returns:
            ``(grad_inp, None, grad_weight)`` -- one entry per ``forward``
            argument; ``gate`` gets ``None`` because it has no gradient.
        """
        # The upstream gradient is made contiguous before it reaches the
        # extension; the saved tensors follow as (inp, gate, weight).
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        return grad_inp, None, grad_weight


# nn.Module wrapper that evaluates its experts through MOEFunction.
class MOELayer(nn.Module):
    """Mixture-of-experts linear layer evaluated through ``MOEFunction``.

    One ``(out_feat, in_feat)`` weight matrix is kept per expert; all of them
    are stacked into a single ``(num_expert, out_feat, in_feat)`` parameter.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill each expert slice with a freshly initialised ``nn.Linear`` weight."""
        # Use the layer's own sizes (self.in_feat / self.out_feat) rather than
        # any same-named module-level variable.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Apply ``MOEFunction`` to ``inp`` and ``gate`` with this layer's weights."""
        return MOEFunction.apply(inp, gate, self.weight)


# Pure-PyTorch layer with the same parameter layout as MOELayer.
class MOELayer_einsum(nn.Module):
    """Mixture-of-experts linear layer written with plain PyTorch ops.

    Holds a ``(num_expert, out_feat, in_feat)`` weight parameter and, for each
    input row, multiplies by the weight matrix of the expert its gate selects.
    Despite the class name, the forward pass is a per-row matmul loop.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer_einsum, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert slice.
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill each expert slice with a freshly initialised ``nn.Linear`` weight."""
        # Use the layer's own sizes (self.in_feat / self.out_feat) rather than
        # any same-named module-level variable.
        with torch.no_grad():
            for i in range(self.num_expert):
                linear = nn.Linear(
                    in_features=self.in_feat, out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, inp, gate):
        """Route every row of ``inp`` through the expert chosen by ``gate``.

        Args:
            inp: tensor of shape ``(batch, in_feat)``.
            gate: integer tensor with one expert index per row of ``inp``.

        Returns:
            Tensor of shape ``(batch, out_feat)`` created with
            ``inp.new_zeros`` (same dtype and device as ``inp``).
        """
        # Tensor indexing needs int64 indices, so widen the gate first.
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            # (out_feat, in_feat) @ (in_feat,) -> (out_feat,)
            x[i] = self.weight[gate_long[i]] @ inp[i]
        return x


# Script: run both layers on the same random input and print the results.
# Problem size for the run below.
batch_size = 1
num_expert = 1
in_feat = 3
out_feat = 3

moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
# Copy the weights across so both layers compute with identical parameters.
moe_einsum.weight.data = moe.weight.data.clone()

# Inputs are drawn on the CPU and then moved to the GPU; gate is cast to int32.
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

print(inp.type())
print(moe.weight.data.type())

# inp and gate are printed both before and after the moe() call.
print(inp)
print(gate)
output = moe(inp, gate)

print(inp)
print(gate)
# The second layer gets clones of the inputs rather than the same tensors.
output_einsum = moe_einsum(inp.clone(), gate.clone())

print(output)
print(output_einsum)

# Backward pass through the first layer, currently disabled:
#y = output.mean()
#y.backward()