# moe.py
# Mixture-of-experts layer: a CUDA-backed implementation (moe_cuda), a
# pure-PyTorch reference implementation, and a small comparison script.
# Author: Jiezhong Qiu.
import math
from torch import nn
from torch.autograd import Function
import torch

import moe_cuda

# Fixed seeds so the weight initialisation and the random test inputs used by
# the script at the bottom of this file are reproducible across runs.
# (torch.manual_seed already seeds every device; the explicit CUDA call is
# kept to make the intent obvious.)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

class MOEFunction(Function):
    """Autograd wrapper around the ``moe_cuda`` forward/backward kernels.

    Gradients are returned for ``input`` and ``weight``; ``gate`` is treated
    as non-differentiable and always receives ``None``.
    """

    @staticmethod
    def forward(ctx, input, gate, weight):
        """Run ``moe_cuda.forward`` and return the first tensor it yields.

        All three inputs are saved because ``moe_cuda.backward`` consumes
        them again together with the incoming gradient.
        """
        # NOTE(review): moe_cuda.forward presumably returns a list of tensors
        # of which only the first is the layer output — confirm.
        output = moe_cuda.forward(input, gate, weight)
        ctx.save_for_backward(input, gate, weight)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``(grad_input, None, grad_weight)`` from ``moe_cuda.backward``."""
        # NOTE(review): .contiguous() presumably because the kernel assumes a
        # dense layout for grad_out — confirm against the extension.
        grad_input, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        return grad_input, None, grad_weight


class MOELayer(nn.Module):
    """Mixture-of-experts linear layer backed by the ``moe_cuda`` kernels.

    Holds one ``out_feat x in_feat`` weight matrix per expert. The CUDA kernel
    is expected to match the pure-PyTorch reference ``MOELayer_einsum``:
    each input row is multiplied by the matrix of the expert chosen for it
    in ``gate``.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: size of each input row.
        out_feat: size of each output row.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Storage is fully overwritten by reset_parameters().
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert exactly like a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                # self.out_feat, not a bare ``out_feat``: the bare name is not
                # a local here and only worked by accident through a
                # module-level global of the same name.
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, input, gate):
        """Apply the selected expert to each row of ``input`` via MOEFunction."""
        return MOEFunction.apply(input, gate, self.weight)


class MOELayer_einsum(nn.Module):
    """Pure-PyTorch reference implementation of the mixture-of-experts layer.

    Holds one ``out_feat x in_feat`` weight matrix per expert and computes
    ``output[b] = weight[gate[b]] @ input[b]`` for every row ``b``.

    Args:
        num_expert: number of expert weight matrices.
        in_feat: size of each input row.
        out_feat: size of each output row.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
        super(MOELayer_einsum, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Storage is fully overwritten by reset_parameters().
        self.weight = nn.Parameter(
            torch.empty(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert exactly like a fresh ``nn.Linear``."""
        with torch.no_grad():
            for i in range(self.num_expert):
                # self.out_feat, not a bare ``out_feat`` (which is not a local
                # and only resolved through a module-level global).
                linear = nn.Linear(in_features=self.in_feat,
                                   out_features=self.out_feat)
                self.weight[i].copy_(linear.weight)

    def forward(self, input, gate):
        """Apply each row's selected expert.

        Args:
            input: ``(batch_size, in_feat)`` tensor.
            gate: ``(batch_size,)`` integer tensor of expert indices.

        Returns:
            ``(batch_size, out_feat)`` tensor with
            ``output[b] = weight[gate[b]] @ input[b]``.
        """
        gate_long = gate.long()
        output = input.new_zeros((input.size(0), self.out_feat))
        # Iterate over the experts actually used (at most num_expert Python
        # iterations) and do one batched matmul per expert, instead of one
        # matmul per batch row. Unlike gathering weight[gate] this never
        # materialises a (batch, out_feat, in_feat) tensor.
        for expert in gate_long.unique().tolist():
            rows = (gate_long == expert).nonzero(as_tuple=True)[0]
            output[rows] = input[rows] @ self.weight[expert].t()
        return output

if __name__ == "__main__":
    # Smoke test: run the CUDA-backed MOELayer and the pure-PyTorch
    # MOELayer_einsum on identical weights and inputs, then compare outputs.
    # Guarded so that importing this module neither needs a GPU nor runs it.
    batch_size = 2
    num_expert = 2
    in_feat = 2
    out_feat = 4

    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
    # Share the weights so both layers compute the same function.
    moe_einsum.weight.data = moe.weight.data.clone()

    inputs = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    print(inputs)
    print(gate)
    output = moe(inputs, gate)

    # Printed again after the CUDA call, presumably to check that the kernel
    # left its inputs untouched.
    print(inputs)
    print(gate)
    output_einsum = moe_einsum(inputs, gate)

    print(output)
    print(output_einsum)
    print("outputs match:", torch.allclose(output, output_einsum, atol=1e-5))
    # TODO: also compare gradients, e.g. via output.mean().backward().