"tests/unit/utils/test_get_optim_files.py" did not exist on "1b2721adcd96656bb1f27d1f2f60947567b2d505"
layers.py 3.74 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fmoe_functions import *


class FMoELinear(nn.Module):
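    """A bank of num_expert linear layers applied as one batched operator.

    The expert weights live in a single (num_expert, out_feat, in_feat)
    parameter and are applied by the MOELinear autograd function from
    fmoe_functions, with fwd_expert_count telling it how many input rows
    belong to each expert.
    """
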
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(FMoELinear, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
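        # Initialize every expert's weight exactly the way nn.Linear does,
        # by copying from a freshly constructed layer of the same shape.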
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, fwd_expert_count):
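        # inp holds input rows already grouped by expert; fwd_expert_count
        # gives the number of rows assigned to each expert.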
        return MOELinear.apply(inp, self.weight, fwd_expert_count)


class FMoENaiveGate(nn.Module):
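    """A naive top-k gate: one linear layer scores all num_expert * world_size
    experts per token, and a softmax over the top-k scores gives the weights
    used to mix the selected experts' outputs.
    """
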
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super(FMoENaiveGate, self).__init__()
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
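        # Score every expert for every token, keep the top-k logits, and
        # normalize them so the expert outputs can be mixed downstream.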
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
                largest=True, sorted=False) # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)

        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # flattened to (B*L*top_k,)

        return gate_top_k_idx, gate_score


def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
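    """Scatter each input row to its expert (possibly on another worker), run
    the expert linears with the activation between consecutive layers, then
    gather the results back into the original row order.
    """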
    (pos, local_expert_count, global_expert_count, fwd_expert_count, 
            fwd_batch_size) = moe_prepare_forward(gate, num_expert, world_size)
    x = MOEScatter.apply(inp, pos, local_expert_count, global_expert_count, 
            fwd_batch_size, world_size)
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
        x = l(x, fwd_expert_count)
    x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
            inp.shape[0], world_size)
    return x


class FMoETransformerMLP(nn.Module):
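    """A transformer FFN block backed by a mixture of expert MLPs.

    A naive top-k gate routes each token to k of the expert two-layer MLPs
    (d_model -> d_hidden -> d_model); the k expert outputs are mixed by the
    gate scores and added to the residual. Layer norm is applied to the input
    when pre_lnorm is set, otherwise to the output of the residual add.
    """
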
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
            world_size=1, activation=F.gelu,
            top_k=2, pre_lnorm=False):
        super(FMoETransformerMLP, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.pre_lnorm = pre_lnorm
        self.top_k = top_k

        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model) 

        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)

        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = nn.Parameter(torch.zeros(d_model, dtype=torch.float32))

    def forward(self, inp):
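        # inp: (B x L x d_model). Returns (output, self.bias); the bias is
        # returned alongside the output rather than added here.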
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate_top_k_idx, gate_score = self.gate(inp)

        # TODO: merge replication into local_scatter
        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, 
                dim=0) # (BxLxtop_k) x d_model
        x = _fmoe_full_forward(inp, gate_top_k_idx, 
                [self.htoh4, self.h4toh], self.activation,
                self.num_expert, self.world_size)

        core_out = x.view(-1, self.top_k, self.d_model) # (BxL) x top_k x d_model 
        core_out = torch.bmm(gate_score, core_out) # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
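

# A minimal usage sketch, assuming world_size=1 (single process) and that the
# fmoe_functions extension backing MOEScatter / MOELinear / MOEGather has been
# built; shapes are illustrative and the extension may require CUDA tensors:
#
#   moe = FMoETransformerMLP(num_expert=4, d_model=16, d_hidden=64,
#           world_size=1)
#   inp = torch.randn(8, 2, 16)    # B x L x d_model
#   output, bias = moe(inp)        # output: (8, 2, 16), bias: (16,)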