# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import (
    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
    _ATTENTION_CLASSES,
    _FEEDFORWARD_CLASSES,
    _get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"


# Aryan/YiYi TODO: we need to make the guider class a config mixin, so this is probably not needed;
# either remove it or make it serializable
@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the transformer blocks to skip within the stack of blocks identified by `fqn`.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
            provide the correct fqn.
        skip_attention (`bool`, defaults to `True`):
            Whether to skip attention blocks.
        skip_ff (`bool`, defaults to `True`):
            Whether to skip feed-forward blocks.
        skip_attention_scores (`bool`, defaults to `False`):
            Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
            projections as the output of scaled dot product attention.
        dropout (`float`, defaults to `1.0`):
            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
            meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
            skipped layers are fully retained, which is equivalent to not skipping any layers.
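
    Example (a minimal sketch; `"transformer_blocks"` assumes the target model exposes a stack under that name):

    ```python
    >>> from diffusers.hooks import LayerSkipConfig

    >>> # Skip the attention and feed-forward sublayers of blocks 10 and 20.
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> config.to_dict()["indices"]
    [10, 20]
    ```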
    """

    indices: List[int]
    fqn: str = "auto"
    skip_attention: bool = True
    skip_attention_scores: bool = False
    skip_ff: bool = True
    dropout: float = 1.0

    def __post_init__(self):
        if not (0 <= self.dropout <= 1):
            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
            raise ValueError(
                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
            )

    def to_dict(self):
        return asdict(self)

    @staticmethod
    def from_dict(data: dict) -> "LayerSkipConfig":
        return LayerSkipConfig(**data)


class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
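    r"""
    A `TorchFunctionMode` that intercepts calls to `torch.nn.functional.scaled_dot_product_attention` and returns the
    `value` tensor directly, skipping the attention score computation, provided the query and value sequence lengths
    match; otherwise, and for every other intercepted function, the call is dispatched unchanged.
    """
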
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            query = kwargs.get("query", None)
            key = kwargs.get("key", None)
            value = kwargs.get("value", None)
            query = query if query is not None else args[0]
            key = key if key is not None else args[1]
            value = value if value is not None else args[2]
            # If the Q sequence length does not match KV sequence length, methods like
            # Perturbed Attention Guidance cannot be used (because the caller expects
            # the same sequence length as Q, but if we return V here, it will not match).
            # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
            # the overall effect would be that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
            if query.shape[2] == value.shape[2]:
                return value
        return func(*args, **kwargs)
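

# A minimal usage sketch (not executed); `q`, `k`, `v` are hypothetical (batch, heads, seq_len, dim) tensors:
#
#     with AttentionScoreSkipFunctionMode():
#         # Returns `v` directly instead of computing attention, since q and v share the same sequence length.
#         out = torch.nn.functional.scaled_dot_product_attention(q, k, v)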


class AttentionProcessorSkipHook(ModelHook):
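    r"""
    Skips the computation of an attention module. When `skip_attention_scores` is enabled, the original forward runs
    under `AttentionScoreSkipFunctionMode` so that scaled dot product attention returns the `value` tensor directly.
    Otherwise, with `dropout == 1.0`, the registered `skip_processor_output_fn` produces the skipped output; with a
    lower dropout, the regular output is computed and dropout is applied to it.
    """
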
    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
        self.skip_processor_output_fn = skip_processor_output_fn
        self.skip_attention_scores = skip_attention_scores
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.skip_attention_scores:
            if not math.isclose(self.dropout, 1.0):
                raise ValueError(
                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                )
            with AttentionScoreSkipFunctionMode():
                output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            if math.isclose(self.dropout, 1.0):
                output = self.skip_processor_output_fn(module, *args, **kwargs)
            else:
                output = self.fn_ref.original_forward(*args, **kwargs)
                output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


class FeedForwardSkipHook(ModelHook):
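    r"""
    Skips a feed-forward module. With `dropout == 1.0`, the input hidden states are returned unchanged; with a lower
    dropout, the regular output is computed and dropout is applied to it.
    """
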
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            output = kwargs.get("hidden_states", None)
            if output is None:
                output = kwargs.get("x", None)
            if output is None and len(args) > 0:
                output = args[0]
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


class TransformerBlockSkipHook(ModelHook):
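    r"""
    Skips an entire transformer block. With `dropout == 1.0`, the input `hidden_states` (and `encoder_hidden_states`,
    when the block is registered as returning them) are passed through unchanged; with a lower dropout, the block
    output is computed and dropout is applied to it.
    """
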
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def initialize_hook(self, module):
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
            if self._metadata.return_encoder_hidden_states_index is None:
                output = original_hidden_states
            else:
                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                    "encoder_hidden_states", args, kwargs
                )
                output = (original_hidden_states, original_encoder_hidden_states)
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`torch.nn.Module`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel
    >>> from diffusers.hooks import LayerSkipConfig, apply_layer_skip

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
    """
    _apply_layer_skip_hook(module, config)


def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.skip_attention and config.skip_attention_scores:
        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
        raise ValueError(
            "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
        )

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
            f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        if config.skip_ff:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):
                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = FeedForwardSkipHook(config.dropout)
                    registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )