Unverified commit 658e24e8, authored by Aryan and committed by GitHub

[core] Pyramid Attention Broadcast (#9562)



* start pyramid attention broadcast

* add coauthor
Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------
Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
parent fb420664
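
For orientation before the diff: a minimal sketch of how the Pyramid Attention Broadcast cache introduced by this PR might be enabled on a pipeline, assembled from the APIs visible in the changes below (PyramidAttentionBroadcastConfig, the denoiser's enable_cache/disable_cache from CacheMixin, and the new current_timestep pipeline property). The checkpoint name, prompt, and config values are illustrative only, not prescriptive.

import torch

from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

# Illustrative checkpoint; any pipeline whose denoiser supports the new cache mixin should look similar.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Reuse attention outputs every 2nd step, but only while the current timestep is inside the given range.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)
config.current_timestep_callback = lambda: pipe.current_timestep
pipe.transformer.enable_cache(config)

video = pipe(prompt="A cat walking on grass", num_inference_steps=50).frames[0]

# The hooks can be removed again when caching is no longer wanted.
pipe.transformer.disable_cache()
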
@@ -456,6 +456,10 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
     def attention_kwargs(self):
         return self._attention_kwargs
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt

@@ -577,6 +581,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         device = self._execution_device

@@ -644,6 +649,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = latents.to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

@@ -678,6 +684,8 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]
@@ -602,6 +602,10 @@ class LattePipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt

@@ -633,7 +637,7 @@ class LattePipeline(DiffusionPipeline):
         clean_caption: bool = True,
         mask_feature: bool = True,
         enable_temporal_attentions: bool = True,
-        decode_chunk_size: Optional[int] = None,
+        decode_chunk_size: int = 14,
     ) -> Union[LattePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.

@@ -729,6 +733,7 @@ class LattePipeline(DiffusionPipeline):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default height and width to transformer

@@ -790,6 +795,7 @@ class LattePipeline(DiffusionPipeline):
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -850,6 +856,8 @@ class LattePipeline(DiffusionPipeline):
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if output_type == "latents":
             deprecation_message = (
                 "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."

@@ -858,7 +866,7 @@ class LattePipeline(DiffusionPipeline):
             output_type = "latent"
 
         if not output_type == "latent":
-            video = self.decode_latents(latents, video_length, decode_chunk_size=14)
+            video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
@@ -21,8 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import Mochi1LoraLoaderMixin
-from ...models.autoencoders import AutoencoderKLMochi
-from ...models.transformers import MochiTransformer3DModel
+from ...models import AutoencoderKLMochi, MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     is_torch_xla_available,

@@ -467,6 +466,10 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
     def attention_kwargs(self):
         return self._attention_kwargs
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt

@@ -591,6 +594,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Define call parameters

@@ -660,6 +664,9 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
                 if self.interrupt:
                     continue
 
+                # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
+                # to make sure we're using the correct non-reversed timestep values.
+                self._current_timestep = 1000 - t
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

@@ -705,6 +712,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if output_type == "latent":
             video = latents
         else:
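
The note in the Mochi hunk above is worth a second look: its scheduler runs timesteps in the reversed convention, so the pipeline stores `1000 - t` to keep `current_timestep` comparable across pipelines. Below is a tiny sketch, with assumed numbers, of why the conversion matters for a timestep skip range; the `inside` helper is hypothetical and only stands in for whatever range check a caching hook performs.

skip_range = (100, 800)  # expressed in the usual, non-reversed timestep convention


def inside(t: float) -> bool:
    # Hypothetical range check standing in for the hook's internal comparison.
    return skip_range[0] < t < skip_range[1]


t_reversed = 850.0  # value a reversed scheduler might report late in sampling
print(inside(t_reversed))           # False when the raw reversed value is used
print(inside(1000.0 - t_reversed))  # True after converting to the shared convention
# The two answers disagree, which is why the pipeline stores 1000 - t in _current_timestep.
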
@@ -1133,11 +1133,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
     def maybe_free_model_hooks(self):
         r"""
-        Function that offloads all components, removes all model hooks that were added when using
-        `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
-        is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
-        functions correctly when applying enable_model_cpu_offload.
+        Method that performs the following:
+            - Offloads all components.
+            - Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again.
+              In case the model has not been offloaded, this function is a no-op.
+            - Resets stateful diffusers hooks of denoiser components if they were added with
+              [`~hooks.HookRegistry.register_hook`].
+
+        Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions
+        correctly when applying `enable_model_cpu_offload`.
         """
+        for component in self.components.values():
+            if hasattr(component, "_reset_stateful_cache"):
+                component._reset_stateful_cache()
+
         if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
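
The new loop above is what ties the cache hooks into the pipeline lifecycle: any component that exposes `_reset_stateful_cache` gets its per-inference state cleared at the end of `__call__`. A minimal sketch of that duck-typed contract follows; `ToyCachedComponent` is a hypothetical stand-in, only the `_reset_stateful_cache` reset idea comes from the diff.

# Hypothetical stand-in showing the contract used by maybe_free_model_hooks():
# the pipeline does not care what the component is, only that it can reset its cache state.
class ToyCachedComponent:
    def __init__(self):
        self.cache = {"step": 3, "hidden_states": "..."}  # pretend per-inference state

    def _reset_stateful_cache(self):
        self.cache = None  # drop anything accumulated during the last __call__


components = {"transformer": ToyCachedComponent(), "tokenizer": object()}

# Same pattern as the added lines above: reset only components that opt in.
for component in components.values():
    if hasattr(component, "_reset_stateful_cache"):
        component._reset_stateful_cache()

assert components["transformer"].cache is None
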
@@ -2,6 +2,40 @@
 from ..utils import DummyObject, requires_backends
 
 
+class HookRegistry(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
+class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
+def apply_pyramid_attention_broadcast(*args, **kwargs):
+    requires_backends(apply_pyramid_attention_broadcast, ["torch"])
+
+
 class AllegroTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -197,6 +231,21 @@ class AutoencoderTiny(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class CacheMixin(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class CogVideoXTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
from diffusers.utils.testing_utils import CaptureLogger, torch_device


logger = get_logger(__name__)  # pylint: disable=invalid-name


class DummyBlock(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(x)
        x = self.activation(x)
        x = self.proj_out(x)
        return x


class DummyModel(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x)
        x = self.activation(x)
        for block in self.blocks:
            x = block(x)
        x = self.linear_2(x)
        return x


class AddHook(ModelHook):
    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        logger.debug("AddHook pre_forward")
        args = ((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("AddHook post_forward")
        return output


class MultiplyHook(ModelHook):
    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("MultiplyHook pre_forward")
        args = ((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"


class StatefulAddHook(ModelHook):
    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        self.increment = 0

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("StatefulAddHook pre_forward")
        add_value = self.value + self.increment
        self.increment += 1
        args = ((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        self.increment = 0


class SkipLayerHook(ModelHook):
    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook new_forward")
        if self.skip_layer:
            return args[0]
        return self.fn_ref.original_forward(*args, **kwargs)

    def post_forward(self, module, output):
        logger.debug("SkipLayerHook post_forward")
        return output


class HookTests(unittest.TestCase):
    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        params = self.get_module_parameters()
        self.model = DummyModel(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()

        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        return torch.manual_seed(0)

    def test_hook_registry(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        expected_repr = (
            "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
        )

        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")

        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])

    def test_stateful_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
        num_repeats = 3

        for i in range(num_repeats):
            result = self.model(input)
            if i == 0:
                output1 = result

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        registry.reset_stateful_hooks()
        output2 = self.model(input)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))

    def test_inference(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
        output1 = self.model(input).mean().detach().cpu().item()

        registry.remove_hook("multiply_hook")
        new_input = input * 2
        output2 = self.model(new_input).mean().detach().cpu().item()

        registry.remove_hook("add_hook")
        new_input = input * 2 + 1
        output3 = self.model(new_input).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)

    def test_skip_layer_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        input = torch.zeros(1, 4, device=torch_device)
        output = self.model(input).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_skip_layer_internal_block(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        input = torch.zeros(1, 4, device=torch_device)

        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(input).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_invocation_order_stateful_first(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_middle(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook_2")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_last(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "StatefulAddHook pre_forward\n"
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)
@@ -34,13 +34,13 @@ from diffusers.utils.testing_utils import (
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
 
 
 enable_full_determinism()
 
 
-class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = AllegroPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -59,14 +59,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     test_layerwise_casting = True
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = AllegroTransformer3DModel(
             num_attention_heads=2,
             attention_head_dim=12,
             in_channels=4,
             out_channels=4,
-            num_layers=1,
+            num_layers=num_layers,
             cross_attention_dim=24,
             sample_width=8,
             sample_height=8,
@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
     to_np,

@@ -41,7 +42,7 @@ from ..test_pipelines_common import (
 enable_full_determinism()
 
 
-class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = CogVideoXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -60,7 +61,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     test_layerwise_casting = True
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = CogVideoXTransformer3DModel(
             # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings

@@ -72,7 +73,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             out_channels=4,
             time_embed_dim=2,
             text_embed_dim=32,  # Must match with tiny-random-t5
-            num_layers=1,
+            num_layers=num_layers,
             sample_width=2,  # latent width: 2 -> final width: 16
             sample_height=2,  # latent height: 2 -> final height: 16
             sample_frames=9,  # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
@@ -19,12 +19,15 @@ from diffusers.utils.testing_utils import (
 from ..test_pipelines_common import (
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )
 
 
-class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
+class FluxPipelineFastTests(
+    unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin
+):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])

@@ -33,13 +36,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
     test_xformers_attention = False
     test_layerwise_casting = True
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = FluxTransformer2DModel(
             patch_size=1,
             in_channels=4,
-            num_layers=1,
-            num_single_layers=1,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
             attention_head_dim=16,
             num_attention_heads=2,
             joint_attention_dim=32,
@@ -30,13 +30,13 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
 
 
 enable_full_determinism()
 
 
-class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = HunyuanVideoPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])

@@ -55,15 +55,15 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     test_layerwise_casting = True
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = HunyuanVideoTransformer3DModel(
             in_channels=4,
             out_channels=4,
             num_attention_heads=2,
             attention_head_dim=10,
-            num_layers=1,
-            num_single_layers=1,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
             num_refiner_layers=1,
             patch_size=1,
             patch_size_t=1,
@@ -27,6 +27,7 @@ from diffusers import (
     DDIMScheduler,
     LattePipeline,
     LatteTransformer3DModel,
+    PyramidAttentionBroadcastConfig,
 )
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (

@@ -38,13 +39,13 @@ from diffusers.utils.testing_utils import (
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
 
 
 enable_full_determinism()
 
 
-class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
     pipeline_class = LattePipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

@@ -54,11 +55,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     required_optional_params = PipelineTesterMixin.required_optional_params
     test_layerwise_casting = True
 
-    def get_dummy_components(self):
+    pab_config = PyramidAttentionBroadcastConfig(
+        spatial_attention_block_skip_range=2,
+        temporal_attention_block_skip_range=2,
+        cross_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(100, 700),
+        temporal_attention_timestep_skip_range=(100, 800),
+        cross_attention_timestep_skip_range=(100, 800),
+        spatial_attention_block_identifiers=["transformer_blocks"],
+        temporal_attention_block_identifiers=["temporal_transformer_blocks"],
+        cross_attention_block_identifiers=["transformer_blocks"],
+    )
+
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = LatteTransformer3DModel(
             sample_size=8,
-            num_layers=1,
+            num_layers=num_layers,
             patch_size=2,
             attention_head_dim=8,
             num_attention_heads=3,
@@ -24,10 +24,12 @@ from diffusers import (
     DDIMScheduler,
     DiffusionPipeline,
     KolorsPipeline,
+    PyramidAttentionBroadcastConfig,
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
+from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
 from diffusers.models.attention_processor import AttnProcessor

@@ -2322,6 +2324,141 @@ class SDXLOptionalComponentsTesterMixin:
         self.assertLess(max_diff, expected_max_difference)
 
 
+class PyramidAttentionBroadcastTesterMixin:
+    pab_config = PyramidAttentionBroadcastConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(100, 800),
+        spatial_attention_block_identifiers=["transformer_blocks"],
+    )
+
+    def test_pyramid_attention_broadcast_layers(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        num_layers = 0
+        num_single_layers = 0
+        dummy_component_kwargs = {}
+        dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
+        if "num_layers" in dummy_component_parameters:
+            num_layers = 2
+            dummy_component_kwargs["num_layers"] = num_layers
+        if "num_single_layers" in dummy_component_parameters:
+            num_single_layers = 2
+            dummy_component_kwargs["num_single_layers"] = num_single_layers
+
+        components = self.get_dummy_components(**dummy_component_kwargs)
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        denoiser.enable_cache(self.pab_config)
+
+        expected_hooks = 0
+        if self.pab_config.spatial_attention_block_skip_range is not None:
+            expected_hooks += num_layers + num_single_layers
+        if self.pab_config.temporal_attention_block_skip_range is not None:
+            expected_hooks += num_layers + num_single_layers
+        if self.pab_config.cross_attention_block_skip_range is not None:
+            expected_hooks += num_layers + num_single_layers
+
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        count = 0
+        for module in denoiser.modules():
+            if hasattr(module, "_diffusers_hook"):
+                hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                if hook is None:
+                    continue
+                count += 1
+                self.assertTrue(
+                    isinstance(hook, PyramidAttentionBroadcastHook),
+                    "Hook should be of type PyramidAttentionBroadcastHook.",
+                )
+                self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
+        self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
+
+        # Perform dummy inference step to ensure state is updated
+        def pab_state_check_callback(pipe, i, t, kwargs):
+            for module in denoiser.modules():
+                if hasattr(module, "_diffusers_hook"):
+                    hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                    if hook is None:
+                        continue
+                    self.assertTrue(
+                        hook.state.cache is not None,
+                        "Cache should have updated during inference.",
+                    )
+                    self.assertTrue(
+                        hook.state.iteration == i + 1,
+                        "Hook iteration state should have updated during inference.",
+                    )
+            return {}
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 2
+        inputs["callback_on_step_end"] = pab_state_check_callback
+        pipe(**inputs)[0]
+
+        # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
+        for module in denoiser.modules():
+            if hasattr(module, "_diffusers_hook"):
+                hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+                if hook is None:
+                    continue
+                self.assertTrue(
+                    hook.state.cache is None,
+                    "Cache should be reset to None after inference.",
+                )
+                self.assertTrue(
+                    hook.state.iteration == 0,
+                    "Iteration should be reset to 0 after inference.",
+                )
+
+    def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2):
+        # We need to use higher tolerance because we are using a random model. With a converged/trained
+        # model, the tolerance can be lower.
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        num_layers = 2
+        components = self.get_dummy_components(num_layers=num_layers)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # Run inference without PAB
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        original_image_slice = output.flatten()
+        original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
+
+        # Run inference with PAB enabled
+        self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        denoiser.enable_cache(self.pab_config)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_pab_enabled = output.flatten()
+        image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
+
+        # Run inference with PAB disabled
+        denoiser.disable_cache()
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_pab_disabled = output.flatten()
+        image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
+
+        assert np.allclose(
+            original_image_slice, image_slice_pab_enabled, atol=expected_atol
+        ), "PAB outputs should not differ much in specified timestep range."
+        assert np.allclose(
+            original_image_slice, image_slice_pab_disabled, atol=1e-4
+        ), "Outputs from normal inference and after disabling cache should not differ."
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.