# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention


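# Module classes treated as attention layers when matching submodules.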
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
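# Module classes treated as feed-forward layers when matching submodules.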
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

# Attribute names under which diffusers transformer models register their lists of transformer
# blocks, grouped by the role the blocks play (spatial, temporal, or cross-attention).
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

# Deduplicated union of all block identifiers above.
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

# Layers supported for group offloading and layerwise casting
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    # TODO(aryan): look into torch.nn.LayerNorm and torch.nn.GroupNorm later; they seem to cause issues with
    # CogVideoX because the same norm layer is invoked twice in CogVideoXLayerNorm
)


def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
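    """Return the submodule of `module` at the fully qualified name `fqn`, or `None` if it does not exist.

    Unlike `torch.nn.Module.get_submodule`, which raises an `AttributeError` when the target is
    missing, this helper returns `None` so callers can probe names safely.

    Example (illustrative; the model and block path are assumptions, not fixed API):

        attn = _get_submodule_from_fqn(transformer, "transformer_blocks.0.attn1")
        if attn is not None:
            ...  # e.g. attach a hook to just this block
    """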
    for submodule_name, submodule in module.named_modules():
        if submodule_name == fqn:
            return submodule
    return None