transformer.py
r"""
Adaptation of an MoE MLP layer to act as the MLP layer in a Transformer block.
"""
import torch
import torch.nn as nn
from .layers import FMoE
from .linear import FMoELinear
from .fastermoe.config import switch_from_env


class _Expert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
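        # Each FMoELinear stacks the weights of all `num_expert` local experts:
        # htoh4 expands d_model -> d_hidden, h4toh projects d_hidden back to d_model.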
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is configurable, but is
        called h4 for convenience). Then apply the activation. Finally shrink
        back to h.
        """
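        # fwd_expert_count[i] is the number of tokens routed to local expert i;
        # the FMoELinear layers use it to apply each expert's weights to its slice.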
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x


class FMoETransformerMLP(FMoE):
    r"""
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of the MLP in each expert.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU(),
        expert_dp_comm="none",
        expert_rank=0,
        **kwargs
    ):
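        # Factory handed to FMoE: it is called once per local expert, so every
        # expert is an independent _Expert with its own pair of FMoELinear weights.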
        def one_expert(d_model):
            return _Expert(1, d_model, d_hidden, activation, rank=0)

        expert = one_expert
        super().__init__(
            num_expert=num_expert, d_model=d_model, expert=expert, **kwargs
        )
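        # Tag expert parameters with the `expert_dp_comm` group so distributed
        # training can treat them separately from ordinary data-parallel weights.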
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r"""
        Wrap the FMoE forward with reshape operations so the input can carry
        arbitrary leading dimensions (e.g. batch and sequence length): the input
        is flattened to (tokens, d_model) before routing and restored to its
        original shape afterwards.
        """
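        # Flatten all leading dimensions (e.g. batch and sequence) into one token dimension.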
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)
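

# Usage sketch (illustrative only, not part of the library). Assumes a single
# process with a CUDA device and the compiled fastmoe extension; gating and
# distributed options are forwarded to FMoE through **kwargs.
#
#     mlp = FMoETransformerMLP(num_expert=4, d_model=1024, d_hidden=4096).cuda()
#     x = torch.randn(8, 128, 1024, device="cuda")   # (batch, seq_len, d_model)
#     y = mlp(x)                                     # same shape as x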