Unverified Commit 844221ae authored by Aryan, committed by GitHub

[core] FasterCache (#10163)



* init

* update

* update

* update

* make style

* update

* fix

* make it work with guidance distilled models

* update

* make fix-copies

* add tests

* update

* apply_faster_cache -> apply_fastercache

* fix

* reorder

* update

* refactor

* update docs

* add fastercache to CacheMixin

* update tests

* Apply suggestions from code review

* make style

* try to fix partial import error

* Apply style fixes

* raise warning

* update

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 9b2c0a7d
...@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
pipe.transformer.enable_cache(config)
```
## FasterCache

[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.

FasterCache speeds up inference in diffusion transformers by:

- Reusing attention states across successive inference steps, since they are typically very similar
- Skipping the unconditional branch of classifier-free guidance: the unconditional and conditional branch outputs are largely redundant at the same timestep, so the unconditional output can be approximated from the conditional one
```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 781),
    tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```
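Only one caching technique can be enabled on a model at a time. Below is a minimal sketch of turning caching off and on again, reusing the `pipe` and `config` objects from the example above; the exact values are illustrative, not recommended defaults.

```python
# Minimal sketch, assuming `pipe` and `config` from the example above.

# `enable_cache` raises a ValueError if another caching technique is already
# enabled, so disable the current one first.
pipe.transformer.disable_cache()

# Re-enable FasterCache with the same (or an updated) configuration.
pipe.transformer.enable_cache(config)
```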
### CacheMixin
[[autodoc]] CacheMixin
...@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
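`apply_faster_cache` can also be called on the denoiser module directly instead of going through `enable_cache`. The snippet below is a minimal sketch, assuming a loaded pipeline `pipe` as in the example above; the configuration values are illustrative only.

```python
from diffusers import FasterCacheConfig, apply_faster_cache

# Minimal sketch: apply FasterCache through the functional API instead of
# `enable_cache`. The values below are illustrative, not tuned recommendations.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)
apply_faster_cache(pipe.transformer, config)
```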
...@@ -131,8 +131,10 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_pyramid_attention_broadcast",
]
)
...@@ -703,7 +705,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
-from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .hooks import (
FasterCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
......
...@@ -2,6 +2,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
......
This diff is collapsed.
...@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
...@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
def __repr__(self) -> str:
return (
-f"PyramidAttentionBroadcastConfig("
f"PyramidAttentionBroadcastConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
...@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
return module
-def apply_pyramid_attention_broadcast(
-module: torch.nn.Module,
-config: PyramidAttentionBroadcastConfig,
-):
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
...@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
-registry.register_hook(hook, "pyramid_attention_broadcast")
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
...@@ -24,6 +24,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
"""
_cache_config = None
...@@ -59,17 +60,31 @@ class CacheMixin:
```
"""
-from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from ..hooks import (
FasterCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
if self.is_cache_enabled:
raise ValueError(
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
-from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
...@@ -77,7 +92,11 @@ class CacheMixin:
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
-registry.remove_hook("pyramid_attention_broadcast", recurse=True)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
......
...@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
-deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
......
...@@ -37,7 +37,6 @@ from torch import Tensor, nn
from typing_extensions import Self
from .. import __version__
-from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
...@@ -504,6 +503,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
from ..hooks import apply_layerwise_casting
user_provided_patterns = True
if skip_modules_pattern is None:
...@@ -570,6 +570,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
... )
```
"""
from ..hooks import apply_group_offloading
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
......
...@@ -817,7 +817,7 @@ class LattePipeline(DiffusionPipeline):
# predict noise model_output
noise_pred = self.transformer(
-latent_model_input,
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
enable_temporal_attentions=enable_temporal_attentions,
......
...@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
...@@ -32,6 +47,10 @@ class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
......
...@@ -31,6 +31,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
...@@ -42,7 +43,9 @@ from ..test_pipelines_common import (
enable_full_determinism()
-class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
......
...@@ -7,7 +7,13 @@ import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers import (
AutoencoderKL,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
nightly,
...@@ -18,6 +24,7 @@ from diffusers.utils.testing_utils import (
)
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
...@@ -27,7 +34,11 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
-unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
...@@ -38,6 +49,14 @@ class FluxPipelineFastTests(
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
......
...@@ -21,6 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConf
from diffusers import (
AutoencoderKLHunyuanVideo,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
...@@ -30,13 +31,20 @@ from diffusers.utils.testing_utils import (
torch_device,
)
-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)
enable_full_determinism()
-class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
...@@ -56,6 +64,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoTransformer3DModel(
......
...@@ -25,6 +25,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
FasterCacheConfig,
LattePipeline,
LatteTransformer3DModel,
PyramidAttentionBroadcastConfig,
...@@ -40,13 +41,20 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
)
enable_full_determinism()
-class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
class LattePipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
):
pipeline_class = LattePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
...@@ -69,6 +77,15 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste
cross_attention_block_identifiers=["transformer_blocks"],
)
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
temporal_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
temporal_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
)
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LatteTransformer3DModel(
......
...@@ -33,13 +33,13 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
...@@ -59,13 +59,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
-def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 2):
torch.manual_seed(0)
transformer = MochiTransformer3DModel(
patch_size=2,
num_attention_heads=2,
attention_head_dim=8,
-num_layers=2,
num_layers=num_layers,
pooled_projection_dim=16,
in_channels=12,
out_channels=None,
......
...@@ -23,13 +23,16 @@ from diffusers import (
ConsistencyDecoderVAE,
DDIMScheduler,
DiffusionPipeline,
FasterCacheConfig,
KolorsPipeline,
PyramidAttentionBroadcastConfig,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
apply_faster_cache,
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
...@@ -2551,6 +2554,167 @@ class PyramidAttentionBroadcastTesterMixin:
), "Outputs from normal inference and after disabling cache should not differ."
class FasterCacheTesterMixin:
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
)
def test_faster_cache_basic_warning_or_errors_raised(self):
components = self.get_dummy_components()
logger = logging.get_logger("diffusers.hooks.faster_cache")
logger.setLevel(logging.INFO)
# Check if a warning is raised when no attention_weight_callback is provided
pipe = self.pipeline_class(**components)
with CaptureLogger(logger) as cap_logger:
config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
apply_faster_cache(pipe.transformer, config)
self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)
# Check if an error is raised when an unsupported tensor format is used
pipe = self.pipeline_class(**components)
with self.assertRaises(ValueError):
config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
apply_faster_cache(pipe.transformer, config)
def test_faster_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
return pipe(**inputs)[0]
# Run inference without FasterCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache enabled
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(
original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
), "FasterCache outputs should not differ much in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
def test_faster_cache_state(self):
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 0
num_single_layers = 0
dummy_component_kwargs = {}
dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
if "num_layers" in dummy_component_parameters:
num_layers = 2
dummy_component_kwargs["num_layers"] = num_layers
if "num_single_layers" in dummy_component_parameters:
num_single_layers = 2
dummy_component_kwargs["num_single_layers"] = num_single_layers
components = self.get_dummy_components(**dummy_component_kwargs)
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe.transformer.enable_cache(self.faster_cache_config)
expected_hooks = 0
if self.faster_cache_config.spatial_attention_block_skip_range is not None:
expected_hooks += num_layers + num_single_layers
if self.faster_cache_config.temporal_attention_block_skip_range is not None:
expected_hooks += num_layers + num_single_layers
# Check if faster_cache denoiser hook is attached
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
self.assertTrue(
hasattr(denoiser, "_diffusers_hook")
and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook),
"Hook should be of type FasterCacheDenoiserHook.",
)
# Check if all blocks have faster_cache block hook attached
count = 0
for name, module in denoiser.named_modules():
if hasattr(module, "_diffusers_hook"):
if name == "":
# Skip the root denoiser module
continue
count += 1
self.assertTrue(
isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook),
"Hook should be of type FasterCacheBlockHook.",
)
self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.")
# Perform inference to ensure that states are updated correctly
def faster_cache_state_check_callback(pipe, i, t, kwargs):
for name, module in denoiser.named_modules():
if not hasattr(module, "_diffusers_hook"):
continue
if name == "":
# Root denoiser module
state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
if not self.faster_cache_config.is_guidance_distilled:
self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.")
self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.")
else:
# Internal blocks
state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.")
self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.")
return {}
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
inputs["callback_on_step_end"] = faster_cache_state_check_callback
_ = pipe(**inputs)[0]
# After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
for name, module in denoiser.named_modules():
if not hasattr(module, "_diffusers_hook"):
continue
if name == "":
# Root denoiser module
state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.")
self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.")
else:
# Internal blocks
state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
self.assertTrue(state.batch_size is None, "Batch size should be reset to None.")
self.assertTrue(state.cache is None, "Cache should be reset to None.")
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
......