Unverified Commit 658e24e8 authored by Aryan, committed by GitHub

[core] Pyramid Attention Broadcast (#9562)



* start pyramid attention broadcast

* add coauthor
Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------
Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
parent fb420664
...@@ -598,6 +598,8 @@ ...@@ -598,6 +598,8 @@
title: Attention Processor title: Attention Processor
- local: api/activations - local: api/activations
title: Custom activation functions title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization - local: api/normalization
title: Custom normalization layers title: Custom normalization layers
- local: api/utilities - local: api/utilities
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by skipping attention computations at certain inference steps and broadcasting (re-using) the cached attention states from a previous step. The attention states change very little between successive inference steps: the difference is largest in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross-attention blocks. Cross-attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. If a block skip range of `N` is configured, the attention states for that block type are recomputed once every `N` steps and the cached states are broadcast for the `N - 1` steps in between. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB by passing a [`~PyramidAttentionBroadcastConfig`] to a pipeline's transformer through its `enable_cache` method. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
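PAB can be disabled again with the [`~CacheMixin.disable_cache`] method of the same module, which removes the registered hooks and restores the original attention computation:
```python
pipe.transformer.disable_cache()
```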
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
...@@ -28,6 +28,7 @@ from .utils import ( ...@@ -28,6 +28,7 @@ from .utils import (
_import_structure = { _import_structure = {
"configuration_utils": ["ConfigMixin"], "configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"], "loaders": ["FromOriginalModelMixin"],
"models": [], "models": [],
"pipelines": [], "pipelines": [],
...@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable: ...@@ -75,6 +76,13 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else: else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend( _import_structure["models"].extend(
[ [
"AllegroTransformer3DModel", "AllegroTransformer3DModel",
...@@ -90,6 +98,7 @@ else: ...@@ -90,6 +98,7 @@ else:
"AutoencoderKLTemporalDecoder", "AutoencoderKLTemporalDecoder",
"AutoencoderOobleck", "AutoencoderOobleck",
"AutoencoderTiny", "AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel", "CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel", "CogView3PlusTransformer2DModel",
"ConsisIDTransformer3DModel", "ConsisIDTransformer3DModel",
...@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -588,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403 from .utils.dummy_pt_objects import * # noqa F403
else: else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import ( from .models import (
AllegroTransformer3DModel, AllegroTransformer3DModel,
AsymmetricAutoencoderKL, AsymmetricAutoencoderKL,
...@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -602,6 +612,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder, AutoencoderKLTemporalDecoder,
AutoencoderOobleck, AutoencoderOobleck,
AutoencoderTiny, AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel, CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel, CogView3PlusTransformer2DModel,
ConsisIDTransformer3DModel, ConsisIDTransformer3DModel,
......
...@@ -2,4 +2,6 @@ from ..utils import is_torch_available ...@@ -2,4 +2,6 @@ from ..utils import is_torch_available
if is_torch_available(): if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
...@@ -30,6 +30,9 @@ class ModelHook: ...@@ -30,6 +30,9 @@ class ModelHook:
_is_stateful = False _is_stateful = False
def __init__(self):
self.fn_ref: "HookFunctionReference" = None
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r""" r"""
Hook that is executed when a model is initialized. Hook that is executed when a model is initialized.
...@@ -48,8 +51,6 @@ class ModelHook: ...@@ -48,8 +51,6 @@ class ModelHook:
module (`torch.nn.Module`): module (`torch.nn.Module`):
The module attached to this hook. The module attached to this hook.
""" """
module.forward = module._old_forward
del module._old_forward
return module return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
...@@ -99,6 +100,29 @@ class ModelHook: ...@@ -99,6 +100,29 @@ class ModelHook:
return module return module
class HookFunctionReference:
def __init__(self) -> None:
"""A container class that maintains mutable references to forward pass functions in a hook chain.
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
entire forward pass structure.
Attributes:
pre_forward: A callable that processes inputs before the main forward pass.
post_forward: A callable that processes outputs after the main forward pass.
forward: The current forward function in the hook chain.
original_forward: The original forward function, stored when a hook provides a custom new_forward.
The class enables hook removal by allowing updates to the forward chain through reference modification rather
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
be updated, preserving the execution order of the remaining hooks.
"""
self.pre_forward = None
self.post_forward = None
self.forward = None
self.original_forward = None
class HookRegistry: class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None: def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__() super().__init__()
...@@ -107,51 +131,71 @@ class HookRegistry: ...@@ -107,51 +131,71 @@ class HookRegistry:
self._module_ref = module_ref self._module_ref = module_ref
self._hook_order = [] self._hook_order = []
self._fn_refs = []
def register_hook(self, hook: ModelHook, name: str) -> None: def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys(): if name in self.hooks.keys():
logger.warning(f"Hook with name {name} already exists, replacing it.") raise ValueError(
f"Hook with name {name} already exists in the registry. Please use a different name or "
if hasattr(self._module_ref, "_old_forward"): f"first remove the existing hook and then add a new one."
old_forward = self._module_ref._old_forward )
else:
old_forward = self._module_ref.forward
self._module_ref._old_forward = self._module_ref.forward
self._module_ref = hook.initialize_hook(self._module_ref) self._module_ref = hook.initialize_hook(self._module_ref)
if hasattr(hook, "new_forward"): def create_new_forward(function_reference: HookFunctionReference):
rewritten_forward = hook.new_forward
def new_forward(module, *args, **kwargs): def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs) args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
output = rewritten_forward(module, *args, **kwargs) output = function_reference.forward(*args, **kwargs)
return hook.post_forward(module, output) return function_reference.post_forward(module, output)
else:
def new_forward(module, *args, **kwargs): return new_forward
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = old_forward(*args, **kwargs) forward = self._module_ref.forward
return hook.post_forward(module, output)
fn_ref = HookFunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.forward = forward
if hasattr(hook, "new_forward"):
fn_ref.original_forward = forward
fn_ref.forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)
rewritten_forward = create_new_forward(fn_ref)
self._module_ref.forward = functools.update_wrapper( self._module_ref.forward = functools.update_wrapper(
functools.partial(new_forward, self._module_ref), old_forward functools.partial(rewritten_forward, self._module_ref), rewritten_forward
) )
hook.fn_ref = fn_ref
self.hooks[name] = hook self.hooks[name] = hook
self._hook_order.append(name) self._hook_order.append(name)
self._fn_refs.append(fn_ref)
def get_hook(self, name: str) -> Optional[ModelHook]: def get_hook(self, name: str) -> Optional[ModelHook]:
if name not in self.hooks.keys(): return self.hooks.get(name, None)
return None
return self.hooks[name]
def remove_hook(self, name: str, recurse: bool = True) -> None: def remove_hook(self, name: str, recurse: bool = True) -> None:
if name in self.hooks.keys(): if name in self.hooks.keys():
num_hooks = len(self._hook_order)
hook = self.hooks[name] hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]
old_forward = fn_ref.forward
if fn_ref.original_forward is not None:
old_forward = fn_ref.original_forward
if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].forward = old_forward
self._module_ref = hook.deinitalize_hook(self._module_ref) self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name] del self.hooks[name]
self._hook_order.remove(name) self._hook_order.pop(index)
self._fn_refs.pop(index)
if recurse: if recurse:
for module_name, module in self._module_ref.named_modules(): for module_name, module in self._module_ref.named_modules():
...@@ -161,7 +205,7 @@ class HookRegistry: ...@@ -161,7 +205,7 @@ class HookRegistry:
module._diffusers_hook.remove_hook(name, recurse=False) module._diffusers_hook.remove_hook(name, recurse=False)
def reset_stateful_hooks(self, recurse: bool = True) -> None: def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in self._hook_order: for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name] hook = self.hooks[hook_name]
if hook._is_stateful: if hook._is_stateful:
hook.reset_state(self._module_ref) hook.reset_state(self._module_ref)
...@@ -180,9 +224,13 @@ class HookRegistry: ...@@ -180,9 +224,13 @@ class HookRegistry:
return module._diffusers_hook return module._diffusers_hook
def __repr__(self) -> str: def __repr__(self) -> str:
hook_repr = "" registry_repr = ""
for i, hook_name in enumerate(self._hook_order): for i, hook_name in enumerate(self._hook_order):
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
hook_repr = self.hooks[hook_name].__repr__()
else:
hook_repr = self.hooks[hook_name].__class__.__name__
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
if i < len(self._hook_order) - 1: if i < len(self._hook_order) - 1:
hook_repr += "\n" registry_repr += "\n"
return f"HookRegistry(\n{hook_repr}\n)" return f"HookRegistry(\n{registry_repr}\n)"
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
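# The identifiers above are used as regular-expression patterns: they are matched against submodule names
# (via `re.search`) to decide which attention layers the spatial, temporal and cross-attention skip ranges apply to.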
@dataclass
class PyramidAttentionBroadcastConfig:
r"""
Configuration for Pyramid Attention Broadcast.
Args:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Controls how often the output of a specific spatial attention block is recomputed. If this is set to the value
`N`, the attention computation will be skipped `N - 1` times (i.e., cached attention states will be re-used)
before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Controls how often the output of a specific temporal attention block is recomputed. If this is set to the value
`N`, the attention computation will be skipped `N - 1` times (i.e., cached attention states will be re-used)
before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Controls how often the output of a specific cross-attention block is recomputed. If this is set to the value
`N`, the attention computation will be skipped `N - 1` times (i.e., cached attention states will be re-used)
before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks", "single_transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
spatial_attention_block_skip_range: Optional[int] = None
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
# so not added for now)
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig("
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
f" current_timestep_callback={self.current_timestep_callback}\n"
")"
)
class PyramidAttentionBroadcastState:
r"""
State for Pyramid Attention Broadcast.
Attributes:
iteration (`int`):
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
called before starting a new inference forward pass for PAB to work correctly.
cache (`Any`):
The cached output from the previous forward pass. This is used to re-use the attention states when the
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
"""
def __init__(self) -> None:
self.iteration = 0
self.cache = None
def reset(self):
self.iteration = 0
self.cache = None
def __repr__(self):
cache_repr = ""
if self.cache is None:
cache_repr = "None"
else:
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
class PyramidAttentionBroadcastHook(ModelHook):
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
_is_stateful = True
def __init__(
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
self.timestep_skip_range = timestep_skip_range
self.block_skip_range = block_skip_range
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = PyramidAttentionBroadcastState()
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
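# Attention is recomputed when there is no cached output yet, on the first iteration, whenever the current
# timestep falls outside the active skip range, and on every `block_skip_range`-th iteration; otherwise the
# attention states cached from the previous computation are broadcast (re-used).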
should_compute_attention = (
self.state.cache is None
or self.state.iteration == 0
or not is_within_timestep_range
or self.state.iteration % self.block_skip_range == 0
)
if should_compute_attention:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
output = self.state.cache
self.state.cache = output
self.state.iteration += 1
return output
def reset_state(self, module: torch.nn.Module) -> None:
self.state.reset()
return module
def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
than in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
config (`PyramidAttentionBroadcastConfig`):
The configuration to use for Pyramid Attention Broadcast.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
```
"""
if config.current_timestep_callback is None:
raise ValueError(
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
)
if (
config.spatial_attention_block_skip_range is None
and config.temporal_attention_block_skip_range is None
and config.cross_attention_block_skip_range is None
):
logger.warning(
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
"To avoid this warning, please set one of the above parameters."
)
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
# PAB has been implemented specifically for Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
continue
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
def _apply_pyramid_attention_broadcast_on_attention_class(
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
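# Classify the layer by matching the configured block identifiers (regular expressions) against the layer name
# and by checking the module's `is_cross_attention` flag; a layer is only considered if the corresponding skip
# range is set in the config.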
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_cross_attention = (
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
and config.cross_attention_block_skip_range is not None
and getattr(module, "is_cross_attention", False)
)
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
elif is_cross_attention:
block_skip_range = config.cross_attention_block_skip_range
timestep_skip_range = config.cross_attention_timestep_skip_range
block_type = "cross"
if block_skip_range is None or timestep_skip_range is None:
logger.info(
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration."
)
return False
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
_apply_pyramid_attention_broadcast_hook(
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
)
return True
def _apply_pyramid_attention_broadcast_hook(
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
attention states will be re-used) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, "pyramid_attention_broadcast")
...@@ -39,6 +39,7 @@ if is_torch_available(): ...@@ -39,6 +39,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [ _import_structure["controlnets.controlnet_hunyuan"] = [
...@@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE, ConsistencyDecoderVAE,
VQModel, VQModel,
) )
from .cache_utils import CacheMixin
from .controlnets import ( from .controlnets import (
ControlNetModel, ControlNetModel,
ControlNetUnionModel, ControlNetUnionModel,
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.logging import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
class CacheMixin:
r"""
A class for enabling and disabling caching techniques on diffusion models.
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
"""
_cache_config = None
@property
def is_cache_enabled(self) -> bool:
return self._cache_config is not None
def enable_cache(self, config) -> None:
r"""
Enable caching techniques on the model.
Args:
config (`PyramidAttentionBroadcastConfig`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
```
"""
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
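# The import is done lazily inside the method, presumably to avoid a circular import between the models and
# hooks submodules.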
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
self._cache_config = None
def _reset_stateful_cache(self, recurse: bool = True) -> None:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
...@@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay ...@@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
...@@ -156,7 +157,7 @@ class CogVideoXBlock(nn.Module): ...@@ -156,7 +157,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
""" """
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
import torch import torch
...@@ -19,13 +20,14 @@ from torch import nn ...@@ -19,13 +20,14 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle from ..normalization import AdaLayerNormSingle
class LatteTransformer3DModel(ModelMixin, ConfigMixin): class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
""" """
......
...@@ -24,6 +24,7 @@ from ...utils import is_torch_version, logging ...@@ -24,6 +24,7 @@ from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
...@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module): ...@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
return hidden_states return hidden_states
class AllegroTransformer3DModel(ModelMixin, ConfigMixin): class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
""" """
......
...@@ -35,6 +35,7 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, Ad ...@@ -35,6 +35,7 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, Ad
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
...@@ -227,7 +228,7 @@ class FluxTransformerBlock(nn.Module): ...@@ -227,7 +228,7 @@ class FluxTransformerBlock(nn.Module):
class FluxTransformer2DModel( class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
): ):
""" """
The Transformer model introduced in Flux. The Transformer model introduced in Flux.
......
...@@ -25,6 +25,7 @@ from ...loaders import PeftAdapterMixin ...@@ -25,6 +25,7 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import ( from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,
...@@ -502,7 +503,7 @@ class HunyuanVideoTransformerBlock(nn.Module): ...@@ -502,7 +503,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r""" r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
......
...@@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay ...@@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_lay
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
...@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module): ...@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):
@maybe_allow_in_graph @maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r""" r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
......
...@@ -683,6 +683,10 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -683,6 +683,10 @@ class AllegroPipeline(DiffusionPipeline):
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -815,6 +819,7 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -815,6 +819,7 @@ class AllegroPipeline(DiffusionPipeline):
negative_prompt_attention_mask, negative_prompt_attention_mask,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False self._interrupt = False
# 2. Default height and width to transformer # 2. Default height and width to transformer
...@@ -892,6 +897,7 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -892,6 +897,7 @@ class AllegroPipeline(DiffusionPipeline):
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -933,6 +939,8 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -933,6 +939,8 @@ class AllegroPipeline(DiffusionPipeline):
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
latents = latents.to(self.vae.dtype) latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents) video = self.decode_latents(latents)
......
...@@ -494,6 +494,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -494,6 +494,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def attention_kwargs(self): def attention_kwargs(self):
return self._attention_kwargs return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -627,6 +631,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -627,6 +631,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False self._interrupt = False
# 2. Default call parameters # 2. Default call parameters
...@@ -705,6 +710,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -705,6 +710,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -763,6 +769,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -763,6 +769,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5 # Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:] latents = latents[:, additional_frames:]
......
...@@ -540,6 +540,10 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -540,6 +540,10 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def attention_kwargs(self): def attention_kwargs(self):
return self._attention_kwargs return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -680,6 +684,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -680,6 +684,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False self._interrupt = False
# 2. Default call parameters # 2. Default call parameters
...@@ -766,6 +771,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -766,6 +771,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -818,6 +824,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -818,6 +824,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
...@@ -591,6 +591,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -591,6 +591,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
def attention_kwargs(self): def attention_kwargs(self):
return self._attention_kwargs return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -728,6 +732,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -728,6 +732,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._current_timestep = None
self._attention_kwargs = attention_kwargs self._attention_kwargs = attention_kwargs
self._interrupt = False self._interrupt = False
...@@ -815,6 +820,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -815,6 +820,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -877,6 +883,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -877,6 +883,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5 # Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:] latents = latents[:, additional_frames:]
......
...@@ -564,6 +564,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -564,6 +564,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
def attention_kwargs(self): def attention_kwargs(self):
return self._attention_kwargs return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -700,6 +704,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -700,6 +704,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False self._interrupt = False
# 2. Default call parameters # 2. Default call parameters
...@@ -786,6 +791,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -786,6 +791,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -844,6 +850,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -844,6 +850,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
......
...@@ -28,8 +28,7 @@ from transformers import ( ...@@ -28,8 +28,7 @@ from transformers import (
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models import AutoencoderKL, FluxTransformer2DModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
...@@ -620,6 +619,10 @@ class FluxPipeline( ...@@ -620,6 +619,10 @@ class FluxPipeline(
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property @property
def interrupt(self): def interrupt(self):
return self._interrupt return self._interrupt
...@@ -775,6 +778,7 @@ class FluxPipeline( ...@@ -775,6 +778,7 @@ class FluxPipeline(
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
...@@ -899,6 +903,7 @@ class FluxPipeline( ...@@ -899,6 +903,7 @@ class FluxPipeline(
if self.interrupt: if self.interrupt:
continue continue
self._current_timestep = t
if image_embeds is not None: if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
...@@ -957,9 +962,10 @@ class FluxPipeline( ...@@ -957,9 +962,10 @@ class FluxPipeline(
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
self._current_timestep = None
if output_type == "latent": if output_type == "latent":
image = latents image = latents
else: else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
......