mlp.py 1.25 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn

from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.linear import ReplicatedLinear


class MLP(nn.Module):
    """
    Plain feed-forward block for DiT blocks (NO gated linear units).

    The input is passed through ``fc_in``, then the activation selected by
    ``act_type``, then ``fc_out``.

    Args:
        input_dim: Input feature size given to ``fc_in``.
        mlp_hidden_dim: Output size of ``fc_in`` and input size of ``fc_out``.
        output_dim: Output size of ``fc_out``; ``None`` means ``input_dim``.
        bias: Forwarded as ``bias`` to both linear layers.
        act_type: Name handed to ``get_act_fn`` to build the activation.
        dtype: Forwarded as ``params_dtype`` to both linear layers.
        prefix: Accepted for interface compatibility; not used by this
            constructor.
    """

    def __init__(
        self,
        input_dim: int,
        mlp_hidden_dim: int,
        output_dim: Optional[int] = None,
        bias: bool = True,
        act_type: str = "gelu_pytorch_tanh",
        dtype: Optional[torch.dtype] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Default the output width to the input width.
        resolved_out_dim = input_dim if output_dim is None else output_dim
        # Both projections are configured with the same bias/dtype options.
        shared = dict(bias=bias, params_dtype=dtype)

        # Attribute order (fc_in, act, fc_out) is kept deliberately so that
        # submodule registration order stays the same.
        self.fc_in = ReplicatedLinear(
            input_dim,
            mlp_hidden_dim,
            **shared,
        )
        self.act = get_act_fn(act_type)
        self.fc_out = ReplicatedLinear(
            mlp_hidden_dim,
            resolved_out_dim,
            **shared,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation and ``fc_out`` to ``x`` in turn."""
        # Each linear layer returns a pair; only the first element is used
        # here (the second is presumably an optional bias -- TODO confirm).
        # ``x`` is rebound at every stage so this frame never keeps a
        # reference to the previous intermediate tensor.
        x, _unused_in = self.fc_in(x)
        x = self.act(x)
        x, _unused_out = self.fc_out(x)
        return x