transformer.py
r"""
Adaptation of the MoE layer to act as the MLP layer in a Transformer block.
"""
import torch
import torch.nn as nn
from .layers import FMoE
from .linear import FMoELinear


class _Expert(nn.Module):
    r"""
    An expert using two FMoELinear modules to speed up the computation of the
    experts within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is configurable, but is
        called h4 for convenience). Then apply the activation. Finally shrink
        back to h.
        """
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x
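

# A reference sketch (illustrative only, not part of the library's API): the fused
# FMoELinear calls in `_Expert.forward` above are semantically equivalent to
# looping over experts and applying an ordinary per-expert nn.Linear. The rows of
# `inp` are assumed to be grouped by expert, and `fwd_expert_count[i]` to hold the
# number of tokens routed to expert i. All names below are hypothetical.
def _expert_forward_reference(inp, fwd_expert_count, w1_list, w2_list, activation):
    outputs = []
    offset = 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        chunk = inp[offset : offset + count]  # tokens routed to expert i
        outputs.append(w2_list[i](activation(w1_list[i](chunk))))
        offset += count
    return torch.cat(outputs, dim=0)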


class FMoETransformerMLP(FMoE):
    r"""
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of the MLP in each expert.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU(),
        expert_dp_comm="none",
        expert_rank=0,
        **kwargs
    ):
        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
        self.experts = _Expert(
            num_expert, d_model, d_hidden, activation, rank=expert_rank
        )
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r"""
        This method wraps the FMoE forward with a reshape: the input is
        flattened to 2-D (tokens, d_model) for expert routing, and the output
        is restored to the original shape.
        """
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)
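

# A minimal usage sketch (illustrative, not part of the library): it assumes the
# fmoe CUDA extension is built and a GPU is available, since FMoE's
# scatter/gather kernels run on CUDA. Shapes and hyper-parameters are arbitrary.
if __name__ == "__main__":
    mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096).cuda()
    x = torch.randn(8, 512, 1024, device="cuda")  # (batch, seq_len, d_model)
    y = mlp(x)                                    # tokens routed through experts
    print(y.shape)                                # same shape as the input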