# mlp.py
import torch.nn as nn
import torch.nn.functional as F

from colossalai.kernel.jit import bias_gelu_impl

from .linear import Linear


class TransformerMLP(nn.Module):
    """Two-layer transformer feed-forward (MLP) block.

    The input is projected from ``hidden_size`` features to
    ``int(hidden_size * mlp_ratio)`` features, passed through GELU, and
    projected back to ``hidden_size`` features. No dropout is applied
    inside this module.

    Both projections are built with ``skip_bias_add=True`` and are used as
    callables returning an ``(output, bias)`` pair. ``forward`` hands the
    second projection's pair back to the caller unchanged (presumably the
    bias has not been added to the output yet -- confirm against
    ``Linear``).

    Args:
        hidden_size: Feature size of the input and of the final output.
        mlp_ratio: Multiplier applied to ``hidden_size`` to obtain the
            intermediate feature size (truncated with ``int``).
        fuse_gelu: If true, the first projection's bias add and the GELU
            are delegated to ``bias_gelu_impl``; otherwise
            ``F.gelu(output + bias)`` is used.
    """

    def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
        super().__init__()

        # Width of the intermediate representation, computed once so both
        # projections are guaranteed to agree on it.
        intermediate_size = int(hidden_size * mlp_ratio)

        # Up-projection: hidden_size -> intermediate_size.
        self.dense_h_to_4h = Linear(hidden_size, intermediate_size, skip_bias_add=True)

        self.bias_gelu_fusion = fuse_gelu
        self.activation_func = F.gelu

        # Down-projection: intermediate_size -> hidden_size.
        self.dense_4h_to_h = Linear(intermediate_size, hidden_size, skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the up-projection, GELU, and the down-projection.

        Args:
            hidden_states: Input tensor; expected to be shaped
                ``[s, b, h]`` (per the original author's note).

        Returns:
            The ``(output, output_bias)`` pair produced by the second
            projection.
        """
        # [s, b, h] -> [s, b, intermediate]; the bias comes back separately.
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            # Bias add and activation handled together by the imported helper.
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, intermediate] -> [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias