from .functions import *
import torch
import torch.nn as nn
import torch.nn.functional as F


class FMoELinear(nn.Module):
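    """A group of num_expert linear layers, one per expert, stored as a single
    (num_expert, out_feat, in_feat) weight tensor and applied through the custom
    MOELinear autograd function imported from .functions. There is no per-expert
    bias term.
    """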
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(FMoELinear, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
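        # Initialize each expert exactly like a standalone nn.Linear of the same
        # shape (PyTorch's default Kaiming-uniform initialization), by copying
        # from a freshly constructed nn.Linear.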
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, fwd_expert_count):
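        # fwd_expert_count[i] tells the custom MOELinear op how many consecutive
        # rows of `inp` belong to expert i; the rows are expected to be grouped
        # by expert, as produced by MOEScatter in _fmoe_full_forward.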
        return MOELinear.apply(inp, self.weight, fwd_expert_count)


class FMoENaiveGate(nn.Module):
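    """A naive gate: a single linear layer scores every expert on every worker
    (num_expert * world_size outputs per token), and the top_k scores per token
    are normalized with a softmax to produce the expert weights.
    """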
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super(FMoENaiveGate, self).__init__()
        # print(f"gate: {num_expert * world_size}")
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)

        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

        return gate_top_k_idx, gate_score


def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
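    """Scatter tokens to the experts selected by `gate`, run them through the
    expert layers in `linears` (applying `activation` between consecutive
    layers), then gather the results back into the original token order.
    """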
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = moe_prepare_forward(gate, num_expert, world_size)
    x = MOEScatter.apply(
        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
    )
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
        x = l(x, fwd_expert_count)
    x = MOEGather.apply(
        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
    )
    return x


class FMoETransformerMLP(nn.Module):
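    """A transformer MLP block whose two dense layers (htoh4 and h4toh) are
    mixture-of-experts layers, with experts selected per token by FMoENaiveGate.

    The forward pass takes a 3-D input whose last dimension is d_model; when
    num_expert != 1, dimension 1 is split across model-parallel ranks and
    re-assembled with all_gather before returning. The module returns a tuple
    (output, bias): the bias parameter is handed back to the caller rather than
    being added inside this module.
    """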
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        model_parallel_size=1,
        model_parallel_rank=1,
        group=None,
        activation=torch.nn.functional.gelu,
        top_k=2,
        pre_lnorm=False,
    ):
        super(FMoETransformerMLP, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.model_parallel_size = model_parallel_size
        self.model_parallel_rank = model_parallel_rank
        self.group = group
        self.activation = activation
        self.pre_lnorm = pre_lnorm
        self.top_k = top_k

        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)

        # print(f"FMoETransformerMLP world_size: {world_size} num_expert: {num_expert}")
        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)

        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(d_model, dtype=torch.float32)
        )

    def forward(self, inp: torch.Tensor):
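        # When experts are enabled (num_expert != 1), each model-parallel rank
        # keeps only its own slice along dimension 1; the slices are re-assembled
        # with all_gather at the end of this forward pass.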
        if self.num_expert != 1:
            B: int = inp.shape[1]
            local_batch_size = B // self.model_parallel_size
            batch_start = local_batch_size * self.model_parallel_rank
            batch_end = min(batch_start + local_batch_size, B)
            inp = inp[:, batch_start:batch_end, :].contiguous()
            # print(inp.shape)
            # print(f"mp_rank: {self.model_parallel_rank}, [{batch_start}, {batch_end})")

        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate_top_k_idx, gate_score = self.gate(inp)

        # TODO: merge replication into local_scatter
        inp = inp.view(-1, self.d_model).repeat_interleave(
            repeats=self.top_k, dim=0
        )  # (BxLxtop_k) x d_model
        x = _fmoe_full_forward(
            inp,
            gate_top_k_idx,
            [self.htoh4, self.h4toh],
            self.activation,
            self.num_expert,
            self.world_size,
        )

        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)

        if self.num_expert != 1:
            world_size = self.model_parallel_size
            if world_size == 1:
                return output, self.bias

            rank = self.model_parallel_rank

            tensor_list = [torch.empty_like(output) for _ in range(world_size)]
            tensor_list[rank] = output
            torch.distributed.all_gather(tensor_list, output, group=self.group)

            # Note: torch.cat already creates a contiguous tensor.
            output = torch.cat(tensor_list, dim=1).contiguous()

        return output, self.bias
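# A minimal usage sketch (illustrative only, not part of the module). It assumes
# the custom ops behind `.functions` (MOELinear, MOEScatter, MOEGather) have been
# built and a GPU is available, and it passes model_parallel_rank=0 explicitly
# because forward slices dimension 1 by rank whenever num_expert != 1:
#
#   moe_mlp = FMoETransformerMLP(
#       num_expert=4, d_model=16, d_hidden=64,
#       world_size=1, model_parallel_size=1, model_parallel_rank=0,
#   ).cuda()
#   inp = torch.randn(8, 2, 16, device="cuda")  # 3-D input, last dim = d_model
#   output, bias = moe_mlp(inp)                 # output has the same shape as inp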