Unverified commit 0454fbb3, authored by Aryan and committed by GitHub

First Block Cache (#11180)



* update

* modify flux single blocks to make them compatible with cache techniques (without too much model-specific intrusive code)

* remove debug logs

* update

* cache context for different batches of data

* fix hidden_states residual bug for single-return outputs; support ltx

* fix controlnet flux

* support flux, ltx i2v, ltx condition

* update

* update

* Update docs/source/en/api/cache.md

* Update src/diffusers/hooks/hooks.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* address review comments pt. 1

* address review comments pt. 2

* cache context refactor; address review pt. 3

* address review comments

* metadata registration with decorators instead of centralized

* support cogvideox

* support mochi

* fix

* remove unused function

* remove central registry based on review

* update

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent cbc8ced2
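
For orientation before the diff: a minimal usage sketch of the API this PR introduces, pieced together from the test changes below (enable_cache, disable_cache, and the FirstBlockCacheConfig import path all appear there). The model ID, threshold value, prompt, and step count are illustrative assumptions, not values from the PR.

import torch

from diffusers import FluxPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Lower thresholds are conservative (fewer skipped computations); the tests
# below use 0.8 only because random, unconverged dummy models need a loose bound.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))

image = pipe("a photo of a cat", num_inference_steps=28).images[0]

pipe.transformer.disable_cache()  # restores plain, uncached inference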
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 if is_conditioning_image_or_video:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    video_coords=video_coords,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        video_coords=video_coords,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
...
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 if self.do_classifier_free_guidance:
...
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # Mochi CFG + Sampling runs in FP32
                 noise_pred = noise_pred.to(torch.float32)
...
@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
-                        timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        attention_kwargs=attention_kwargs,
-                        return_dict=False,
-                    )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
...
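
The context names above follow a simple convention: pipelines that batch the conditional and unconditional branches into one forward pass (LTX, Mochi) wrap the call in a single "cond_uncond" context, while Wan runs two forward passes per step and gives each branch its own context so the two passes keep independent caches. A toy, self-contained sketch of that idea; TinyCachedDenoiser is an illustrative stand-in, not the PR's implementation:

from contextlib import contextmanager


class TinyCachedDenoiser:
    """Illustrative stand-in: one cache slot per named context."""

    def __init__(self):
        self._caches = {}
        self._active = None

    @contextmanager
    def cache_context(self, name: str):
        self._active = name
        try:
            yield
        finally:
            self._active = None

    def __call__(self, x: float) -> float:
        if self._active in self._caches:
            # cache hit: skip the "expensive" forward pass entirely
            return self._caches[self._active]
        out = x * 2.0  # stand-in for the real transformer forward
        self._caches[self._active] = out
        return out


model = TinyCachedDenoiser()
with model.cache_context("cond"):
    cond = model(1.0)
with model.cache_context("uncond"):
    uncond = model(3.0)  # separate slot, so "cond"'s cache is not reused here
assert (cond, uncond) == (2.0, 6.0)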
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class FirstBlockCacheConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class HookRegistry(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
     requires_backends(apply_faster_cache, ["torch"])


+def apply_first_block_cache(*args, **kwargs):
+    requires_backends(apply_first_block_cache, ["torch"])
+
+
 def apply_pyramid_attention_broadcast(*args, **kwargs):
     requires_backends(apply_pyramid_attention_broadcast, ["torch"])
...
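
These dummy objects mirror the new public names so that, in an environment without torch, importing them still succeeds and only using them fails with an actionable message. A small sketch of that behavior (inferred from how requires_backends is used above):

from diffusers.utils import is_torch_available

if not is_torch_available():
    # The dummy class is importable even without torch; constructing it
    # raises a clear "requires the torch backend" error rather than a bare
    # ImportError at import time.
    from diffusers.utils.dummy_pt_objects import FirstBlockCacheConfig

    FirstBlockCacheConfig(threshold=0.2)  # raises: requires the torch backend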
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


+def unwrap_module(module):
+    """Unwraps a module if it was compiled with torch.compile()"""
+    return module._orig_mod if is_compiled_module(module) else module
+
+
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
     """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
...
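
A quick usage sketch for the helper added above, assuming a PyTorch 2.x environment where torch.compile is available (the import path is an assumption based on the neighboring is_compiled_module / fourier_filter helpers):

import torch
import torch.nn as nn

from diffusers.utils.torch_utils import unwrap_module  # assumed module path

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# torch.compile wraps the module in an OptimizedModule that keeps the original
# behind `_orig_mod`; unwrap_module recovers it so hooks can be registered on
# the real module rather than the compiled wrapper.
assert unwrap_module(compiled) is model
assert unwrap_module(model) is model  # no-op for uncompiled modules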
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
...
@@ -45,7 +46,11 @@ enable_full_determinism()


 class CogVideoXPipelineFastTests(
-    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = CogVideoXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
...
@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (
 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
...
@@ -33,11 +34,12 @@ from ..test_pipelines_common import (


 class FluxPipelineFastTests(
-    unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
...
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     to_np,
...
@@ -43,7 +44,11 @@ enable_full_determinism()


 class HunyuanVideoPipelineFastTests(
-    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = HunyuanVideoPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
...
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
     pipeline_class = LTXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
...
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_layerwise_casting = True
     test_group_offloading = True

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = LTXVideoTransformer3DModel(
             in_channels=8,
...
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             num_attention_heads=4,
             attention_head_dim=8,
             cross_attention_dim=32,
-            num_layers=1,
+            num_layers=num_layers,
             caption_channels=32,
         )
...
@@ -32,13 +32,15 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+    PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
     pipeline_class = MochiPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
...
@@ -33,6 +33,7 @@ from diffusers import (
 )
 from diffusers.hooks import apply_group_offloading
 from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
...
@@ -2648,7 +2649,7 @@ class FasterCacheTesterMixin:
         self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
         pipe = create_pipe()
         pipe.transformer.enable_cache(self.faster_cache_config)
-        output = run_forward(pipe).flatten().flatten()
+        output = run_forward(pipe).flatten()
         image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))

         # Run inference with FasterCache disabled
...
@@ -2755,6 +2756,55 @@ class FasterCacheTesterMixin:
         self.assertTrue(state.cache is None, "Cache should be reset to None.")


+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+    # threshold is intentionally set higher than usual values since we're testing with random unconverged models
+    # that will not satisfy the expected properties of the denoiser for caching to be effective
+    first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
+    def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        def create_pipe():
+            torch.manual_seed(0)
+            num_layers = 2
+            components = self.get_dummy_components(num_layers=num_layers)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(device)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
+
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(device)
+            inputs["num_inference_steps"] = 4
+            return pipe(**inputs)[0]
+
+        # Run inference without FirstBlockCache
+        pipe = create_pipe()
+        output = run_forward(pipe).flatten()
+        original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+        # Run inference with FirstBlockCache enabled
+        pipe = create_pipe()
+        pipe.transformer.enable_cache(self.first_block_cache_config)
+        output = run_forward(pipe).flatten()
+        image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+        # Run inference with FirstBlockCache disabled
+        pipe.transformer.disable_cache()
+        output = run_forward(pipe).flatten()
+        image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+        assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+            "FirstBlockCache outputs should not differ much."
+        )
+        assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+            "Outputs from normal inference and after disabling cache should not differ."
+        )
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.
...
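
For readers new to the technique under test, a toy paraphrase of the first-block-cache idea (my summary, not the PR's hook implementation): run the first transformer block every step, and only run the remaining blocks when the first block's output residual changed enough since the previous step; otherwise reuse the cached contribution of the remaining blocks. This is also why the mixin asks for num_layers=2 dummy models: with a single block there is nothing after the first block to skip.

import torch

def toy_forward(blocks, x, state, threshold=0.1):
    first_out = blocks[0](x)
    residual = first_out - x
    prev = state.get("first_residual")
    if prev is not None:
        # relative change of the first block's residual between steps
        diff = (residual - prev).abs().mean() / (prev.abs().mean() + 1e-8)
        if diff < threshold:
            # close enough: skip the remaining blocks, reuse their cached part
            return first_out + state["tail_residual"]
    state["first_residual"] = residual
    out = first_out
    for block in blocks[1:]:
        out = block(out)
    state["tail_residual"] = out - first_out
    return out

blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(3)])
state = {}
x = torch.randn(2, 8)
y1 = toy_forward(blocks, x, state)  # full forward, caches both residuals
y2 = toy_forward(blocks, x, state)  # identical input -> cache hit, blocks 1+ skipped
assert torch.allclose(y1, y2)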