transformer.py
r'''
Adaptation that uses an MoE MLP layer to act as the MLP layer of a
Transformer block.
'''
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear


class _Expert(nn.Module):
    r'''
    An expert that uses two FMoELinear modules, so that the computation of
    all experts within one worker can be performed together for speed.
    '''
    def __init__(self, num_expert, d_model, d_hidden, activation):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r'''
        First expand the input to 4h (the hidden size is configurable, but is
        called h4 for convenience). Then apply the activation. Finally shrink
        it back to h.
        '''
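        # `fwd_expert_count` is assumed (following the FMoE dispatch
        # convention) to hold the number of input rows assigned to each local
        # expert, so each FMoELinear can batch its per-expert computation.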
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x


class FMoETransformerMLP(FMoE):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of the MLP in each expert.
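
    A minimal usage sketch (a single-worker illustration with the default
    gate; the sizes below are arbitrary):

        mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096)
        x = torch.randn(8, 512, 1024)   # (batch, seq_len, d_model)
        y = mlp(x)                      # same shape as x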
    '''
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=2,
        pre_lnorm=False
    ):
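        # expert_fn is handed to the FMoE base class so that it can call back
        # into the local `_Expert` module constructed below.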
        def expert_fn(inp, gate):
            return self.experts(inp, gate)
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                world_size=world_size, mp_group=mp_group, expert_fn=expert_fn)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.mark_parallel_comm()

    def forward(self, inp: torch.Tensor):
        r'''
        This module wraps up the FMoE module with a reshape to
        (tokens, d_model), a residual connection and layer normalization.
        '''
        original_shape = inp.shape
        # Flatten all leading dimensions so that each row is one token.
        inp = inp.reshape(-1, self.d_model)
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        # MoE MLP with a residual connection around it.
        output = super().forward(inp) + inp
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output.reshape(original_shape)
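

# A hedged integration sketch (not part of this module): swapping the dense
# feed-forward sub-layer of a hypothetical Transformer block for this MoE MLP.
# `block` and `block.hidden_size` are illustrative names, not a real API.
#
#     block.mlp = FMoETransformerMLP(
#         num_expert=8,
#         d_model=block.hidden_size,
#         d_hidden=4 * block.hidden_size,
#     )
#     # block.forward can then call block.mlp(x) where the dense MLP was
#     # called before; input and output shapes are unchanged.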