# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager


logger = get_logger(__name__)  # pylint: disable=invalid-name

_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"


@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold used to decide whether a forward pass through all transformer blocks is required. A higher
            threshold usually lets more inference steps reuse the cache (fewer blocks are computed), giving faster
            inference but possibly poorer generation quality. A lower threshold may not result in a significant
            speedup. The threshold is compared against the relative absmean difference between the current and
            previously cached residual of the first transformer block. If the difference is below the threshold, the
            remaining transformer blocks are skipped and their cached residuals are reused.
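
    Example:

        A minimal usage sketch; `pipe` is assumed to be an already-loaded Diffusers pipeline whose transformer
        architecture is supported (see `apply_first_block_cache` below for a complete, runnable example):

        ```python
        >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

        >>> # A higher threshold caches more aggressively (faster, lower quality); a lower one recomputes more often.
        >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
        ```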
    """

    threshold: float = 0.05


class FBCSharedBlockState(BaseState):
    def __init__(self) -> None:
        super().__init__()

        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.head_block_residual: torch.Tensor = None
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        self.should_compute: bool = True

    def reset(self):
        self.tail_block_residuals = None
        self.should_compute = True


class FBCHeadBlockHook(ModelHook):
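    r"""
    Hook registered on the first ("head") transformer block. It runs the block normally, computes the block's
    hidden-states residual, and compares it against the residual cached from the previous inference step to decide
    whether the remaining blocks must be computed (see `FirstBlockCacheConfig.threshold`). When the remaining blocks
    are skipped, the cached tail residuals are added to the current first-block output to approximate the full
    forward pass.
    """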
    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        super().__init__()
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if not should_compute:
            # Skip path: approximate the final outputs by adding the cached tail residuals (accumulated over all
            # remaining blocks during the last fully computed step) to the current first-block output.
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
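            # Full compute path: cache the first block's output (used by the tail block to form the tail residuals)
            # and its residual (used for the threshold comparison on the next step).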
            if is_output_tuple:
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
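        # Always compute on the very first step (no cached residual yet). Afterwards, recompute only when the
        # relative absmean change of the first-block residual, mean(|r_t - r_prev|) / mean(|r_prev|), exceeds the
        # configured threshold.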
        shared_state = self.state_manager.get_state()
        if shared_state.head_block_residual is None:
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold


class FBCBlockHook(ModelHook):
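    r"""
    Hook registered on every transformer block after the first. When the head hook has decided to compute, the block
    runs normally and the last ("tail") block records the residual accumulated across all blocks after the head
    block. When the head hook has decided to skip, the block returns its inputs unchanged.
    """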
    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if self._metadata.return_encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
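                # Tail block: record the residual accumulated across all blocks after the head block so that skipped
                # steps can reuse it.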
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = (
                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                    )
                    encoder_hidden_states_residual = (
                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                    )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

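        # Skip path: the head block hook has already added the cached tail residuals, so this block simply passes
        # its inputs through unchanged.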
        if original_encoder_hidden_states is None:
            return_output = original_hidden_states
        else:
            return_output = [None, None]
            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
            return_output = tuple(return_output)
        return return_output


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    Args:
        module (`torch.nn.Module`):
            The PyTorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Example:
        ```python
        >>> import torch
        >>> from diffusers import CogView4Pipeline
        >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
        >>> image.save("output.png")
        ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

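    # Gather all transformer blocks stored under the known block-list attribute names.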
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

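    # The first block becomes the cache decision point ("head"); the last block records the residual of all
    # remaining blocks ("tail") so it can be reapplied on skipped steps.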
    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)


def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state_manager, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state_manager, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)