Unverified Commit 0454fbb3 authored by Aryan, committed by GitHub

First Block Cache (#11180)



* update

* modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)

* remove debug logs

* update

* cache context for different batches of data

* fix hs residual bug for single return outputs; support ltx

* fix controlnet flux

* support flux, ltx i2v, ltx condition

* update

* update

* Update docs/source/en/api/cache.md

* Update src/diffusers/hooks/hooks.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* address review comments pt. 1

* address review comments pt. 2

* cache context refactor; address review pt. 3

* address review comments

* metadata registration with decorators instead of centralized

* support cogvideox

* support mochi

* fix

* remove unused function

* remove central registry based on review

* update

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent cbc8ced2
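
For context, a minimal usage sketch of the feature this PR adds (the checkpoint id, prompt, and threshold value are illustrative; `enable_cache`/`disable_cache` come from the `CacheMixin` changes further down in this diff):

```python
import torch
from diffusers import FirstBlockCacheConfig, FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Register the First Block Cache hooks on the denoiser. A larger threshold skips the
# remaining transformer blocks more often (faster, but potentially lower quality).
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]

# Remove the hooks and clear the stored cache config.
pipe.transformer.disable_cache()
```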
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FasterCacheConfig
 [[autodoc]] apply_faster_cache
+### FirstBlockCacheConfig
+[[autodoc]] FirstBlockCacheConfig
+[[autodoc]] apply_first_block_cache
@@ -133,9 +133,11 @@ else:
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
+           "FirstBlockCacheConfig",
            "HookRegistry",
            "PyramidAttentionBroadcastConfig",
            "apply_faster_cache",
+           "apply_first_block_cache",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -751,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    else:
        from .hooks import (
            FasterCacheConfig,
+           FirstBlockCacheConfig,
            HookRegistry,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
+           apply_first_block_cache,
            apply_pyramid_attention_broadcast,
        )
        from .models import (
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available

if is_torch_available():
    from .faster_cache import FasterCacheConfig, apply_faster_cache
+   from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier]
if self._cached_parameter_indices is not None:
return args[self._cached_parameter_indices[identifier]]
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self._cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self._cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index]
class AttentionProcessorRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._register()
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_attention_processors_metadata()
class TransformerBlockRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._register()
metadata._cls = model_class
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_transformer_blocks_metadata()
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
),
)
# CogView4AttnProcessor
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
# fmt: on
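
As a rough sketch of how the registries above are consumed by the cache hooks later in this diff (the import path assumes this file lands at `src/diffusers/hooks/_helpers.py`, as its relative imports suggest):

```python
from diffusers.hooks._helpers import TransformerBlockRegistry
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

metadata = TransformerBlockRegistry.get(FluxTransformerBlock)

# FluxTransformerBlock.forward returns (encoder_hidden_states, hidden_states), so the
# metadata tells a cache hook at which tuple index each tensor lives.
assert metadata.return_hidden_states_index == 1
assert metadata.return_encoder_hidden_states_index == 0

# Given the args/kwargs of an intercepted block call, a hook can then pull out a
# parameter by name without knowing the block's exact signature:
#     hidden_states = metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
```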
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
but might lead to poorer generation quality. A lower threshold may not result in significant generation
speedup. The threshold is compared against the absmean difference of the residuals between the current and
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
is skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, threshold: float):
self.state_manager = state_manager
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
else:
hidden_states_residual = output - original_hidden_states
shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hidden_states = (
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
)
else:
hidden_states = shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
encoder_hidden_states = (
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return_output = tuple(return_output)
else:
return_output = hidden_states
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
shared_state.head_block_output = head_block_output
shared_state.head_block_residual = hidden_states_residual
return output
def reset_state(self, module):
self.state_manager.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
original_encoder_hidden_states = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
shared_state = self.state_manager.get_state()
if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hidden_states_residual = encoder_hidden_states_residual = None
if isinstance(output, tuple):
hidden_states_residual = (
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
)
encoder_hidden_states_residual = (
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
)
else:
hidden_states_residual = output - shared_state.head_block_output
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
return output
if original_encoder_hidden_states is None:
return_output = original_hidden_states
else:
return_output = [None, None]
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return_output = tuple(return_output)
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Applying FBCBlockHook to '{name}'")
_apply_fbc_block_hook(block, state_manager)
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state_manager, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state_manager, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
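
A hedged sketch of wiring these hooks directly onto a denoiser, without going through a pipeline (the checkpoint id is illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# The first transformer block gets FBCHeadBlockHook; every other block gets FBCBlockHook,
# with the last block marked as the tail. At each step the head block's output residual is
# compared to the previous step's; if the relative change is under the threshold, the
# remaining blocks are skipped and the cached tail residuals are reused.
apply_first_block_cache(transformer, FirstBlockCacheConfig(threshold=0.05))
```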
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple

import torch

from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module


logger = get_logger(__name__)  # pylint: disable=invalid-name


+class BaseState:
+    def reset(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            "BaseState::reset is not implemented. Please implement this method in the derived class."
+        )
+
+
+class StateManager:
+    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+        self._state_cls = state_cls
+        self._init_args = init_args if init_args is not None else ()
+        self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+        self._state_cache = {}
+        self._current_context = None
+
+    def get_state(self):
+        if self._current_context is None:
+            raise ValueError("No context is set. Please set a context before retrieving the state.")
+        if self._current_context not in self._state_cache.keys():
+            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+        return self._state_cache[self._current_context]
+
+    def set_context(self, name: str) -> None:
+        self._current_context = name
+
+    def reset(self, *args, **kwargs) -> None:
+        for name, state in list(self._state_cache.items()):
+            state.reset(*args, **kwargs)
+            self._state_cache.pop(name)
+        self._current_context = None
+
+
class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module

+    def _set_context(self, module: torch.nn.Module, name: str) -> None:
+        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+        for attr_name in dir(self):
+            attr = getattr(self, attr_name)
+            if isinstance(attr, StateManager):
+                attr.set_context(name)
+        return module
+

class HookFunctionReference:
    def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
                hook.reset_state(self._module_ref)

        if recurse:
-            for module_name, module in self._module_ref.named_modules():
+            for module_name, module in unwrap_module(self._module_ref).named_modules():
                if module_name == "":
                    continue
+                module = unwrap_module(module)
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

@@ -223,6 +265,19 @@ class HookRegistry:
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

+    def _set_context(self, name: Optional[str] = None) -> None:
+        for hook_name in reversed(self._hook_order):
+            hook = self.hooks[hook_name]
+            if hook._is_stateful:
+                hook._set_context(self._module_ref, name)
+
+        for module_name, module in unwrap_module(self._module_ref).named_modules():
+            if module_name == "":
+                continue
+            module = unwrap_module(module)
+            if hasattr(module, "_diffusers_hook"):
+                module._diffusers_hook._set_context(name)
+
    def __repr__(self) -> str:
        registry_repr = ""
        for i, hook_name in enumerate(self._hook_order):
...
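
A toy sketch of the `StateManager` contract introduced above (`CounterState` is hypothetical, purely for illustration):

```python
from diffusers.hooks.hooks import BaseState, StateManager


class CounterState(BaseState):
    def __init__(self):
        self.count = 0

    def reset(self):
        self.count = 0


manager = StateManager(CounterState, (), {})

# Each named context lazily gets its own state object, so e.g. conditional and
# unconditional batches keep separate caches.
manager.set_context("cond")
manager.get_state().count += 1

manager.set_context("uncond")
assert manager.get_state().count == 0

manager.reset()  # resets every cached state and clears the current context
```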
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from contextlib import contextmanager
+
from ..utils.logging import get_logger

@@ -25,6 +27,7 @@ class CacheMixin:
    Supported caching techniques:
        - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
        - [FasterCache](https://huggingface.co/papers/2410.19355)
+       - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
    """

    _cache_config = None

@@ -62,8 +65,10 @@ class CacheMixin:
        from ..hooks import (
            FasterCacheConfig,
+           FirstBlockCacheConfig,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
+           apply_first_block_cache,
            apply_pyramid_attention_broadcast,
        )

@@ -72,31 +77,36 @@ class CacheMixin:
                f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
            )

-        if isinstance(config, PyramidAttentionBroadcastConfig):
-            apply_pyramid_attention_broadcast(self, config)
-        elif isinstance(config, FasterCacheConfig):
+        if isinstance(config, FasterCacheConfig):
            apply_faster_cache(self, config)
+        elif isinstance(config, FirstBlockCacheConfig):
+            apply_first_block_cache(self, config)
+        elif isinstance(config, PyramidAttentionBroadcastConfig):
+            apply_pyramid_attention_broadcast(self, config)
        else:
            raise ValueError(f"Cache config {type(config)} is not supported.")

        self._cache_config = config

    def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

        if self._cache_config is None:
            logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
            return

-        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
-            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
-        elif isinstance(self._cache_config, FasterCacheConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        if isinstance(self._cache_config, FasterCacheConfig):
            registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
            registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, FirstBlockCacheConfig):
+            registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

@@ -106,3 +116,15 @@ class CacheMixin:
        from ..hooks import HookRegistry

        HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+    @contextmanager
+    def cache_context(self, name: str):
+        r"""Context manager that provides additional methods for cache management."""
+        from ..hooks import HookRegistry
+
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry._set_context(name)
+
+        yield
+
+        registry._set_context(None)
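
And a sketch of how the new `cache_context` is meant to be used; the pipeline diffs below call it internally, so end users normally only call `enable_cache` (checkpoint id and prompt are illustrative):

```python
import torch
from diffusers import FirstBlockCacheConfig, LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# enable_cache dispatches on the config type and registers the First Block Cache hooks.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))

# Inside the pipeline's denoising loop, the transformer call is wrapped in
# `with self.transformer.cache_context("cond_uncond"):` so cached residuals are tracked per
# batch of data (separate "cond"/"uncond" contexts when CFG runs two forward passes).
video = pipe(prompt="A chimpanzee typing on an old typewriter", num_frames=49).frames[0]

pipe.transformer.disable_cache()
```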
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                )

            block_samples = block_samples + (hidden_states,)

-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
+                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )
            else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

-            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+            single_block_samples = single_block_samples + (hidden_states,)

        # controlnet block
        controlnet_block_samples = ()
...
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
        return hidden_states, encoder_hidden_states


+@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
    def __init__(
        self,
...
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
+        text_seq_len = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

-        return hidden_states
+        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+        return encoder_hidden_states, hidden_states


@maybe_allow_in_graph
@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
                )
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
+                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )
            else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
@@ -530,12 +536,7 @@ class FluxTransformer2DModel(
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
+                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)
...
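
A small sketch of the new `FluxSingleTransformerBlock` calling convention, since it now takes and returns the text stream explicitly (the dimensions below are made up for illustration and much smaller than the real model):

```python
import torch
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

block = FluxSingleTransformerBlock(dim=128, num_attention_heads=4, attention_head_dim=32)

hidden_states = torch.randn(1, 64, 128)          # image tokens
encoder_hidden_states = torch.randn(1, 16, 128)  # text tokens
temb = torch.randn(1, 128)

# The block concatenates text and image tokens internally and splits them again on the way
# out, so callers (and the cache hooks) see the same (encoder_hidden_states, hidden_states)
# pair as with the dual-stream FluxTransformerBlock.
encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(encoder_hidden_states.shape, hidden_states.shape)  # (1, 16, 128) and (1, 64, 128)
```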
@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -249,6 +250,7 @@ class WanRotaryPosEmbed(nn.Module):
        return freqs_cos, freqs_sin


+@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
...
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
...
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
...
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    ofs=ofs_emb,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        ofs=ofs_emb,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
...
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
...
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred_cond = self.transformer(
                        hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        original_size=original_size,
                        target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                        return_dict=False,
                    )[0]

+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            timestep=timestep,
+                            original_size=original_size,
+                            target_size=target_size,
+                            crop_coords=crops_coords_top_left,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]

                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond
...
@@ -912,32 +912,35 @@ class FluxPipeline(
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
...
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
                        guidance=guidance,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_attention_mask=negative_prompt_attention_mask,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            guidance=guidance,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
...
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                noise_pred = noise_pred.float()

                if self.do_classifier_free_guidance:
...