r'''
Adaptation of an MoE layer to act as the MLP (feed-forward) layer in a Transformer.
'''
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear


class _Expert(nn.Module):
    r'''
    An expert that uses two FMoELinear modules to compute all local experts
    within one worker in a batched fashion.
    '''
    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        # Batched expert-parallel linear layers: d_model -> d_hidden -> d_model,
        # each holding the weights of all `num_expert` local experts.
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
                bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
                bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r'''
        First expand the input to 4h (the hidden size is configurable, but is
        called h4 for convenience), then apply the activation, and finally
        shrink it back to h.
        '''
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x
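
# Illustration (a sketch, not part of the module): assuming fastmoe's CUDA
# kernels are built, tokens are already grouped by expert, and
# `fwd_expert_count[i]` holds the number of rows routed to expert i, a single
# worker could exercise `_Expert` like this:
#
#     expert = _Expert(num_expert=4, d_model=16, d_hidden=64,
#                      activation=torch.nn.functional.gelu).cuda()
#     inp = torch.randn(10, 16, device='cuda')   # 10 tokens, grouped by expert
#     counts = torch.tensor([3, 2, 4, 1])        # rows routed to each expert
#     out = expert(inp, counts)                  # -> shape (10, 16)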


class FMoETransformerMLP(FMoE):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of the MLP in each expert.
    * `do_lnorm` adds a `LayerNorm`; `pre_lnorm` selects pre- or post-norm.
    * `dropout` is applied to the expert output before the residual connection.
    '''
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=2,
        do_lnorm=False,
        pre_lnorm=False,
        expert_dp_comm='none',
        dropout=0.1
    ):
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                top_k=top_k, world_size=world_size, mp_group=mp_group)
        self.dropout = nn.Dropout(dropout)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                rank=self.mp_rank)
        if do_lnorm:
            self.layer_norm = nn.LayerNorm(d_model)
            self.pre_lnorm = pre_lnorm
        else:
            # No layer norm at all; `None` distinguishes this from post-norm.
            self.pre_lnorm = None
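        # Tag the expert parameters with `expert_dp_comm` so that fastmoe's
        # distributed wrapper knows how (or whether) to synchronize them
        # across data-parallel workers.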
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r'''
        This method wraps the FMoE forward pass with reshaping, a residual
        connection, and optional layer normalization.
        '''
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        # Keep the un-normalized input for the residual connection.
        residual = inp
        if self.pre_lnorm is not None and self.pre_lnorm:
            inp = self.layer_norm(inp)
        output = super().forward(inp)
        output = self.dropout(output)
        output = output + residual
        if self.pre_lnorm is not None and not self.pre_lnorm:
            output = self.layer_norm(output)
        return output.reshape(original_shape)
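

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming a single process (world_size=1,
    # no model-parallel group) and a CUDA device with fastmoe's kernels built.
    mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                             do_lnorm=True, pre_lnorm=True).cuda()
    x = torch.randn(8, 32, 256, device='cuda')   # (batch, seq_len, d_model)
    y = mlp(x)
    assert y.shape == x.shape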