Unverified Commit 6f3ac305 authored by Aryan, committed by GitHub

[refactor] some shared parts between hooks + docs (#11968)

* update

* try test fix

* add missing link

* fix tests

* Update src/diffusers/hooks/first_block_cache.py

* make style
parent a6d9f6a1
@@ -16,11 +16,11 @@ from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
}
)
# Layers supported for group offloading and layerwise casting
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
......
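For context, a minimal sketch (not part of the change) of how the hook modules consume these shared constants: an `isinstance` filter over `named_modules()`, mirroring the call sites updated later in this diff. The toy module below is hypothetical.

```python
# Illustrative sketch only: filtering submodules with the shared constants from ._common.
import torch
from diffusers.hooks._common import _ATTENTION_CLASSES, _GO_LC_SUPPORTED_PYTORCH_LAYERS

toy = torch.nn.ModuleDict(
    {
        "proj": torch.nn.Linear(8, 8),     # supported -> would be hooked
        "conv": torch.nn.Conv2d(3, 3, 3),  # supported -> would be hooked
        "norm": torch.nn.LayerNorm(8),     # intentionally excluded (see TODO above)
    }
)

for name, submodule in toy.named_modules():
    if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
        print(f"group offloading / layerwise casting would target: {name}")
    if isinstance(submodule, _ATTENTION_CLASSES):
        print(f"cache hooks (FasterCache, PAB, FBC) would target: {name}")
```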
@@ -19,9 +19,9 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook
@@ -30,7 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
Args:
pipeline (`DiffusionPipeline`):
The diffusion pipeline to apply FasterCache to.
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
module (`torch.nn.Module`):
The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FasterCacheConfig`):
The configuration to use for FasterCache.
Example:
@@ -568,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
......
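The example block under `Example:` above is collapsed in this diff. What follows is a hedged usage sketch based only on the documented signature `apply_faster_cache(module, config)`; the pipeline, checkpoint, and config values are illustrative assumptions, not the docstring's actual example.

```python
# Hedged sketch: applying FasterCache to a pipeline's transformer. Config values
# are illustrative; consult the FasterCacheConfig docstring for real defaults.
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    # lets the hook know the current denoising step (assumes the pipeline exposes it)
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)
apply_faster_cache(pipe.transformer, config)
```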
@@ -192,6 +192,38 @@ class FBCBlockHook(ModelHook):
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
"""
Applies [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
to a given module.
First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
to implement generically for a wide range of models and has been integrated first for experimental purposes.
Args:
module (`torch.nn.Module`):
The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FirstBlockCacheConfig`):
The configuration to use for applying the FBCache method.
Example:
```python
>>> import torch
>>> from diffusers import CogView4Pipeline
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
>>> image.save("output.png")
```
"""
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
......
@@ -23,6 +23,7 @@ import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -39,13 +40,6 @@ _GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on
@@ -683,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
......
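For context, a hedged sketch of driving the leaf-level path that now relies on `_GO_LC_SUPPORTED_PYTORCH_LAYERS`; the model, checkpoint, and device choices are assumptions for illustration.

```python
# Hedged sketch: leaf-level group offloading. Only submodules in
# _GO_LC_SUPPORTED_PYTORCH_LAYERS (Conv*, ConvTranspose*, Linear) get their own
# offloading group; remaining parameters/buffers are gathered separately.
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
```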
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Type, Union
import torch
from ..utils import get_logger, is_peft_available, is_peft_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -27,12 +28,6 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return
if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
......
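A hedged sketch of the layerwise-casting entry point whose supported-layer check now comes from `_common`; the model and dtype choices are illustrative assumptions.

```python
# Hedged sketch: weights of supported layers are stored in float8 and upcast to the
# compute dtype per forward; layers matching DEFAULT_SKIP_MODULES_PATTERN
# (norms, embeddings, proj_in/proj_out) are skipped.
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_casting import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```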
@@ -21,6 +21,12 @@ import torch
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from ._common import (
_ATTENTION_CLASSES,
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
)
from .hooks import HookRegistry, ModelHook
@@ -28,10 +34,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
......
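A hedged sketch of Pyramid Attention Broadcast usage with the config whose identifier defaults now come from `._common`; the pipeline and the skip ranges shown are illustrative assumptions, not asserted defaults.

```python
# Hedged sketch: apply PAB to a transformer, skipping attention recomputation within
# the given timestep ranges and reusing the broadcast attention states in between.
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    # assumes the pipeline exposes the current denoising step
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```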
@@ -1394,9 +1394,9 @@ else:
DevicePropertiesUserDict = UserDict
if is_torch_available():
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.group_offloading import (
_GROUP_ID_LAZY_LEAF,
_SUPPORTED_PYTORCH_LAYERS,
_compute_group_hash,
_find_parent_module_in_module_dict,
_gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ if is_torch_available():
elif offload_type == "leaf_level":
# Handle leaf-level module groups
for name, submodule in module.named_modules():
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
# These groups will always have parameters, so a file is expected
expected_files.add(get_hashed_filename(name))
# Handle groups for non-leaf parameters/buffers
modules_with_group_offloading = {
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
}
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
......
@@ -2109,14 +2109,15 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_layerwise_casting_inference_denoiser(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ class PeftLoraLoaderMixinTests:
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)
@@ -2180,7 +2181,7 @@ class PeftLoraLoaderMixinTests:
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
......
@@ -1530,7 +1530,8 @@ class ModelTesterMixin:
@torch.no_grad()
def test_layerwise_casting_inference(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1544,7 +1545,7 @@ class ModelTesterMixin:
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
......