r"""
Adaptation of an MoE MLP layer to act as the MLP layer in a Transformer block.
"""
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear


class _Expert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is variable, but is
        called h4 for convenience), then apply the activation function, and
        finally shrink it back to h.
        """
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
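        # Conceptually, the two fused FMoELinear calls above are equivalent to
        # a Python loop over experts. A rough, unfused sketch (an assumption
        # about the layout: `inp` is grouped by expert along dim 0,
        # `fwd_expert_count` holds per-expert token counts, `w1[i]`/`w2[i]`
        # stand for the i-th expert's weights in `htoh4`/`h4toh`, and biases
        # are omitted for brevity):
        #
        #     out, base = [], 0
        #     for i, cnt in enumerate(fwd_expert_count.tolist()):
        #         chunk = inp[base:base + cnt]
        #         out.append(self.activation(chunk @ w1[i].t()) @ w2[i].t())
        #         base += cnt
        #     x = torch.cat(out, dim=0)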
        return x


class FMoETransformerMLP(FMoE):
    r"""
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of each expert's MLP.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.GELU(),
        gate=NaiveGate,
        top_k=2,
        expert_dp_comm="none",
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=gate,
            top_k=top_k,
            world_size=world_size,
            mp_group=mp_group,
            gate_hook=gate_hook,
            mask=mask,
            mask_dict=mask_dict,
        )
        self.experts = _Expert(
            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
        )
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r"""
        This forward wraps the underlying FMoE forward pass with a reshape:
        the input is flattened to 2-D before routing and restored to its
        original shape afterwards.
        """
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)
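

# A minimal usage sketch (illustrative only, not part of this module; it
# assumes a working fastmoe installation with its CUDA extension and a GPU):
#
#     moe_mlp = FMoETransformerMLP(
#         num_expert=4, d_model=1024, d_hidden=4096, top_k=2
#     ).cuda()
#     x = torch.randn(8, 16, 1024, device="cuda")  # [batch, seq_len, d_model]
#     y = moe_mlp(x)  # routed through the experts; same shape as x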