Unverified Commit 882aa077 authored by SMG, committed by GitHub

feat: Implement V2 FBCaching and Optimize Existing FBCache (#621)

* caching_v2

* rename fb cache and write docstring

* lint

* rename utils to fbcache

* no need to maintain sana-specific code for caching
parent c547f3b9
import torch
from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)
# Warm-up / first generation
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=4.5,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana1.6b-int4.png")
import torch
from diffusers import FluxPipeline
from nunchaku.caching.diffusers_adapters.flux_v2 import apply_cache_on_pipe
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision
precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.12,
residual_diff_threshold_single=0.20,
)
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision
precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save(f"flux.1-dev-{precision}.png")
@@ -16,6 +16,7 @@ from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from ...caching import utils
from ..fbcache import cache_context, create_cache_context, get_current_cache_context
def apply_cache_on_transformer(
@@ -85,7 +86,7 @@ def apply_cache_on_transformer(
@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
- cache_context = utils.get_current_cache_context()
+ cache_context = get_current_cache_context()
if cache_context is not None:
with (
unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
@@ -136,7 +137,7 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, **kwargs):
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
- with utils.cache_context(utils.create_cache_context()):
+ with cache_context(create_cache_context()):
return original_call(self, *args, **kwargs)
pipe.__class__.__call__ = new_call
"""
V2 caching implementation using a separate forward function.
"""
import functools
from diffusers import DiffusionPipeline
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
from ..fbcache import cache_context, create_cache_context
from ..utils_v2 import cached_forward_v2
def apply_cache_on_transformer(
transformer: NunchakuFluxTransformer2DModelV2,
*,
use_double_fb_cache: bool = False,
residual_diff_threshold: float = 0.12,
residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float | None = None,
):
"""
Apply caching to transformer by replacing its forward method.
Args:
transformer: The NunchakuFluxTransformer2DModelV2 instance to apply caching to.
use_double_fb_cache: If True, applies first-block caching separately to the double (multi)
transformer blocks and the single transformer blocks, giving finer-grained caching decisions.
residual_diff_threshold: Fallback threshold for the residual difference; used for the multi
blocks when residual_diff_threshold_multi is not provided.
residual_diff_threshold_multi: Threshold for the residual difference in the multi (double) blocks.
residual_diff_threshold_single: Threshold for the residual difference in the single blocks.
Only used when use_double_fb_cache is True.
Returns:
The transformer with caching applied.
"""
if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold
if getattr(transformer, "_is_cached", False):
# Already cached, just update thresholds
transformer.residual_diff_threshold_multi = residual_diff_threshold_multi
transformer.residual_diff_threshold_single = residual_diff_threshold_single
transformer.use_double_fb_cache = use_double_fb_cache
return transformer
# Store original forward method
transformer._original_forward = transformer.forward
# Set caching parameters
transformer.residual_diff_threshold_multi = residual_diff_threshold_multi
transformer.residual_diff_threshold_single = (
residual_diff_threshold_single if residual_diff_threshold_single is not None else -1.0
)
transformer.use_double_fb_cache = use_double_fb_cache
transformer.verbose = False
transformer.forward = cached_forward_v2.__get__(transformer, transformer.__class__)
transformer._is_cached = True
return transformer
def apply_cache_on_pipe(pipe: DiffusionPipeline, **kwargs):
"""
Apply caching to a Flux pipeline.
"""
if not getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
with cache_context(create_cache_context()):
return original_call(self, *args, **kwargs)
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
apply_cache_on_transformer(pipe.transformer, **kwargs)
return pipe
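Two behaviors of this adapter are worth calling out: because _is_cached and _original_forward are stored on the transformer, calling apply_cache_on_transformer again only updates the thresholds, and setting residual_diff_threshold_multi below zero makes cached_forward_v2 fall back to the original forward. A minimal sketch, assuming `pipeline` has already gone through apply_cache_on_pipe as in the usage example near the top of this commit:
from nunchaku.caching.diffusers_adapters.flux_v2 import apply_cache_on_transformer

# Re-applying on an already-cached transformer only updates the thresholds (no re-wrapping).
apply_cache_on_transformer(
    pipeline.transformer,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.15,
)

# Effectively disable caching for subsequent calls: forward routes to _original_forward.
pipeline.transformer.residual_diff_threshold_multi = -1.0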
"""
Caching utilities for transformer models.
Implements first-block caching to accelerate transformer inference by reusing computations
when input changes are minimal. Supports SANA and Flux architectures.
**Main Classes**
- :class:`CacheContext` : Manages cache buffers and incremental naming.
**Key Functions**
- :func:`get_buffer`, :func:`set_buffer` : Cache buffer management.
- :func:`cache_context` : Context manager for cache operations.
- :func:`are_two_tensors_similar` : Tensor similarity check.
- :func:`apply_prev_hidden_states_residual` : Applies cached residuals.
- :func:`get_can_use_cache` : Checks cache usability.
- :func:`check_and_apply_cache` : Main cache logic.
**Caching Strategy**
1. Compute the first transformer block.
2. Compare the residual with the cached residual.
3. If similar, reuse cached results for the remaining blocks; otherwise, recompute and update cache.
.. note::
Adapted from ParaAttention:
https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
"""
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional, Tuple
import torch
@dataclasses.dataclass
class CacheContext:
"""
Manages cache buffers and incremental naming for transformer model inference.
Attributes
----------
buffers : Dict[str, torch.Tensor]
Stores cached tensor buffers.
incremental_name_counters : DefaultDict[str, int]
Counters for generating unique incremental cache entry names.
"""
buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
def get_incremental_name(self, name=None):
"""
Generate an incremental cache entry name.
Parameters
----------
name : str, optional
Base name for the counter. If None, uses "default".
Returns
-------
str
Incremental name in the format ``"{name}_{counter}"``.
"""
if name is None:
name = "default"
idx = self.incremental_name_counters[name]
self.incremental_name_counters[name] += 1
return f"{name}_{idx}"
def reset_incremental_name(self):
"""
Reset all incremental name counters.
After calling this, :meth:`get_incremental_name` will start from 0 for each name.
"""
self.incremental_name_counters.clear()
# @torch.compiler.disable # This is a torchscript feature
def get_buffer(self, name: str) -> Optional[torch.Tensor]:
"""
Retrieve a cached tensor buffer by name.
Parameters
----------
name : str
Name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
"""
return self.buffers.get(name)
def set_buffer(self, name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the cache.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
"""
self.buffers[name] = buffer
def clear_buffers(self):
"""
Clear all cached tensor buffers.
Removes all stored tensors from the cache.
"""
self.buffers.clear()
@torch.compiler.disable
def get_buffer(name: str) -> torch.Tensor:
"""
Retrieve a cached tensor buffer from the current cache context.
Parameters
----------
name : str
The name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
return cache_context.get_buffer(name)
@torch.compiler.disable
def set_buffer(name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the current cache context.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
cache_context.set_buffer(name, buffer)
_current_cache_context = None
def create_cache_context():
"""
Create a new :class:`CacheContext` for managing cached computations.
Returns
-------
CacheContext
A new cache context instance.
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... # Cached operations here
... pass
"""
return CacheContext()
def get_current_cache_context():
"""
Get the currently active cache context.
Returns
-------
CacheContext or None
The current cache context if one is active, otherwise None.
Examples
--------
>>> with cache_context(create_cache_context()):
... current = get_current_cache_context()
... assert current is not None
"""
return _current_cache_context
@contextlib.contextmanager
def cache_context(cache_context):
"""
Context manager to set the active cache context.
Sets the global cache context for the duration of the ``with`` block, restoring the previous context on exit.
Parameters
----------
cache_context : CacheContext
The cache context to activate.
Yields
------
None
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... set_buffer("key", torch.tensor([1, 2, 3]))
... cached = get_buffer("key")
"""
global _current_cache_context
old_cache_context = _current_cache_context
_current_cache_context = cache_context
try:
yield
finally:
_current_cache_context = old_cache_context
@torch.compiler.disable
def are_two_tensors_similar(t1: torch.Tensor, t2: torch.Tensor, *, threshold: float, parallelized: bool = False):
"""
Check if two tensors are similar based on relative L1 distance.
The relative distance is computed as
``mean(abs(t1 - t2)) / mean(abs(t1))`` and compared to ``threshold``.
Parameters
----------
t1 : torch.Tensor
First tensor.
t2 : torch.Tensor
Second tensor.
threshold : float
Similarity threshold. Tensors are similar if relative distance < threshold.
parallelized : bool, optional
Unused. For API compatibility.
Returns
-------
tuple of (bool, float)
- bool: True if tensors are similar, False otherwise.
- float: The computed relative L1 distance.
"""
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff_ratio = mean_diff / mean_t1
is_similar = diff_ratio < threshold
return is_similar, diff_ratio
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply cached residuals to hidden states.
Parameters
----------
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for ``mode="multi"``).
mode : {"multi", "single"}, default: "multi"
Whether to apply residuals for Flux double blocks or single blocks.
Returns
-------
tuple or torch.Tensor
- If ``mode="multi"``: (updated_hidden_states, updated_encoder_hidden_states)
- If ``mode="single"``: updated_hidden_states
Raises
------
AssertionError
If required cached residuals are not found.
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, "multi_hidden_states_residual must be set before"
hidden_states = hidden_states + hidden_states_residual
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is not None:
enc_hidden_res = get_buffer("multi_encoder_hidden_states_residual")
msg = "multi_encoder_hidden_states_residual must be set before"
assert enc_hidden_res is not None, msg
encoder_hidden_states = encoder_hidden_states + enc_hidden_res
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
elif mode == "single":
single_residual = get_buffer("single_hidden_states_residual")
msg = "single_hidden_states_residual must be set before"
assert single_residual is not None, msg
hidden_states = hidden_states + single_residual
hidden_states = hidden_states.contiguous()
return hidden_states
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
@torch.compiler.disable
def get_can_use_cache(
first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"
):
"""
Check if cached computations can be reused based on residual similarity.
Parameters
----------
first_hidden_states_residual : torch.Tensor
Current first block residual.
threshold : float
Similarity threshold for cache validity.
parallelized : bool, optional
Whether computation is parallelized. Default is False.
mode : {"multi", "single"}, optional
Caching mode. Default is "multi".
Returns
-------
tuple of (bool, float)
- bool: True if the cache can be used (residuals are similar), False otherwise. May be returned as a 0-dim boolean tensor.
- float: The computed relative difference, or the threshold itself if no cache exists yet. May be returned as a 0-dim tensor.
Raises
------
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
buffer_name = "first_multi_hidden_states_residual"
elif mode == "single":
buffer_name = "first_single_hidden_states_residual"
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
prev_res = get_buffer(buffer_name)
if prev_res is None:
return torch.tensor(False, device=first_hidden_states_residual.device), torch.tensor(
threshold, device=first_hidden_states_residual.device
)
is_similar, diff = are_two_tensors_similar(
prev_res,
first_hidden_states_residual,
threshold=threshold,
parallelized=parallelized,
)
return is_similar, diff
def check_and_apply_cache(
*,
first_residual: torch.Tensor,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
threshold: float,
parallelized: bool,
mode: str,
verbose: bool,
call_remaining_fn,
remaining_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
"""
Check and apply cache based on residual similarity.
This function determines whether cached results can be used by comparing the
first block residuals. If the cache is valid, it applies cached computations;
otherwise, it computes new values and updates the cache.
Parameters
----------
first_residual : torch.Tensor
First block residual for similarity comparison.
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for "multi" mode).
threshold : float
Similarity threshold for cache validity.
parallelized : bool
Whether computation is parallelized.
mode : {"multi", "single"}
Caching mode.
verbose : bool
Whether to print caching status messages.
call_remaining_fn : callable
Function to call remaining transformer blocks.
remaining_kwargs : dict
Additional keyword arguments for `call_remaining_fn`.
Returns
-------
tuple
(updated_hidden_states, updated_encoder_hidden_states, threshold)
- updated_hidden_states (torch.Tensor)
- updated_encoder_hidden_states (torch.Tensor or None)
- threshold (float)
"""
can_use_cache, diff = get_can_use_cache(
first_residual,
threshold=threshold,
parallelized=parallelized,
mode=mode,
)
torch._dynamo.graph_break()
if can_use_cache:
if verbose:
diff_val = diff.item() if isinstance(diff, torch.Tensor) else diff
print(f"[{mode.upper()}] Cache hit! diff={diff_val:.4f}, " f"new threshold={threshold:.4f}")
out = apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states, mode=mode)
updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
return updated_h, updated_enc, threshold
old_threshold = threshold
if verbose:
diff_val = diff.item() if isinstance(diff, torch.Tensor) else diff
print(f"[{mode.upper()}] Cache miss. diff={diff_val:.4f}, " f"was={old_threshold:.4f} => now={threshold:.4f}")
if mode == "multi":
set_buffer("first_multi_hidden_states_residual", first_residual)
else:
set_buffer("first_single_hidden_states_residual", first_residual)
result = call_remaining_fn(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, **remaining_kwargs
)
if mode == "multi":
updated_h, updated_enc, hs_res, enc_res = result
set_buffer("multi_hidden_states_residual", hs_res)
set_buffer("multi_encoder_hidden_states_residual", enc_res)
return updated_h, updated_enc, threshold
elif mode == "single":
updated_cat_states, cat_res = result
set_buffer("single_hidden_states_residual", cat_res)
return updated_cat_states, None, threshold
raise ValueError(f"Unknown mode {mode}")
@@ -6,19 +6,9 @@ when input changes are minimal. Supports SANA and Flux architectures.
**Main Classes**
- :class:`CacheContext` : Manages cache buffers and incremental naming.
- :class:`SanaCachedTransformerBlocks` : Cached transformer blocks for SANA models.
- :class:`FluxCachedTransformerBlocks` : Cached transformer blocks for Flux models.
**Key Functions**
- :func:`get_buffer`, :func:`set_buffer` : Cache buffer management.
- :func:`cache_context` : Context manager for cache operations.
- :func:`are_two_tensors_similar` : Tensor similarity check.
- :func:`apply_prev_hidden_states_residual` : Applies cached residuals.
- :func:`get_can_use_cache` : Checks cache usability.
- :func:`check_and_apply_cache` : Main cache logic.
**Caching Strategy**
1. Compute the first transformer block.
@@ -30,463 +20,21 @@ when input changes are minimal. Supports SANA and Flux architectures.
https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
"""
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional, Tuple
import torch
from torch import nn
from nunchaku.caching.fbcache import (
apply_prev_hidden_states_residual,
check_and_apply_cache,
get_can_use_cache,
set_buffer,
)
from nunchaku.models.transformers.utils import pad_tensor
num_transformer_blocks = 19 # FIXME
num_single_transformer_blocks = 38 # FIXME
@dataclasses.dataclass
class CacheContext:
"""
Manages cache buffers and incremental naming for transformer model inference.
Attributes
----------
buffers : Dict[str, torch.Tensor]
Stores cached tensor buffers.
incremental_name_counters : DefaultDict[str, int]
Counters for generating unique incremental cache entry names.
"""
buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
def get_incremental_name(self, name=None):
"""
Generate an incremental cache entry name.
Parameters
----------
name : str, optional
Base name for the counter. If None, uses "default".
Returns
-------
str
Incremental name in the format ``"{name}_{counter}"``.
"""
if name is None:
name = "default"
idx = self.incremental_name_counters[name]
self.incremental_name_counters[name] += 1
return f"{name}_{idx}"
def reset_incremental_name(self):
"""
Reset all incremental name counters.
After calling this, :meth:`get_incremental_name` will start from 0 for each name.
"""
self.incremental_name_counters.clear()
# @torch.compiler.disable # This is a torchscript feature
def get_buffer(self, name: str) -> Optional[torch.Tensor]:
"""
Retrieve a cached tensor buffer by name.
Parameters
----------
name : str
Name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
"""
return self.buffers.get(name)
def set_buffer(self, name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the cache.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
"""
self.buffers[name] = buffer
def clear_buffers(self):
"""
Clear all cached tensor buffers.
Removes all stored tensors from the cache.
"""
self.buffers.clear()
@torch.compiler.disable
def get_buffer(name: str) -> torch.Tensor:
"""
Retrieve a cached tensor buffer from the current cache context.
Parameters
----------
name : str
The name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
return cache_context.get_buffer(name)
@torch.compiler.disable
def set_buffer(name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the current cache context.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
cache_context.set_buffer(name, buffer)
_current_cache_context = None
def create_cache_context():
"""
Create a new :class:`CacheContext` for managing cached computations.
Returns
-------
CacheContext
A new cache context instance.
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... # Cached operations here
... pass
"""
return CacheContext()
def get_current_cache_context():
"""
Get the currently active cache context.
Returns:
CacheContext or None: The current cache context if one is active, None otherwise
Example:
>>> with cache_context(create_cache_context()):
... current = get_current_cache_context()
... assert current is not None
"""
return _current_cache_context
@contextlib.contextmanager
def cache_context(cache_context):
"""
Context manager to set the active cache context.
Sets the global cache context for the duration of the ``with`` block, restoring the previous context on exit.
Parameters
----------
cache_context : CacheContext
The cache context to activate.
Yields
------
None
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... set_buffer("key", torch.tensor([1, 2, 3]))
... cached = get_buffer("key")
"""
global _current_cache_context
old_cache_context = _current_cache_context
_current_cache_context = cache_context
try:
yield
finally:
_current_cache_context = old_cache_context
@torch.compiler.disable
def are_two_tensors_similar(t1: torch.Tensor, t2: torch.Tensor, *, threshold: float, parallelized: bool = False):
"""
Check if two tensors are similar based on relative L1 distance.
The relative distance is computed as
``mean(abs(t1 - t2)) / mean(abs(t1))`` and compared to ``threshold``.
Parameters
----------
t1 : torch.Tensor
First tensor.
t2 : torch.Tensor
Second tensor.
threshold : float
Similarity threshold. Tensors are similar if relative distance < threshold.
parallelized : bool, optional
Unused. For API compatibility.
Returns
-------
tuple of (bool, float)
- bool: True if tensors are similar, False otherwise.
- float: The computed relative L1 distance.
"""
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
return diff < threshold, diff
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply cached residuals to hidden states.
Parameters
----------
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for ``mode="multi"``).
mode : {"multi", "single"}, default: "multi"
Whether to apply residuals for Flux double blocks or single blocks.
Returns
-------
tuple or torch.Tensor
- If ``mode="multi"``: (updated_hidden_states, updated_encoder_hidden_states)
- If ``mode="single"``: updated_hidden_states
Raises
------
AssertionError
If required cached residuals are not found.
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, "multi_hidden_states_residual must be set before"
hidden_states = hidden_states + hidden_states_residual
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is not None:
enc_hidden_res = get_buffer("multi_encoder_hidden_states_residual")
msg = "multi_encoder_hidden_states_residual must be set before"
assert enc_hidden_res is not None, msg
encoder_hidden_states = encoder_hidden_states + enc_hidden_res
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
elif mode == "single":
single_residual = get_buffer("single_hidden_states_residual")
msg = "single_hidden_states_residual must be set before"
assert single_residual is not None, msg
hidden_states = hidden_states + single_residual
hidden_states = hidden_states.contiguous()
return hidden_states
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
@torch.compiler.disable
def get_can_use_cache(
first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"
):
"""
Check if cached computations can be reused based on residual similarity.
Parameters
----------
first_hidden_states_residual : torch.Tensor
Current first block residual.
threshold : float
Similarity threshold for cache validity.
parallelized : bool, optional
Whether computation is parallelized. Default is False.
mode : {"multi", "single"}, optional
Caching mode. Default is "multi".
Returns
-------
tuple of (bool, float)
- bool: True if cache can be used (residuals are similar), False otherwise.
- float: The computed similarity difference, or threshold if no cache exists.
Raises
------
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
buffer_name = "first_multi_hidden_states_residual"
elif mode == "single":
buffer_name = "first_single_hidden_states_residual"
else:
raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
prev_res = get_buffer(buffer_name)
if prev_res is None:
return False, threshold
is_similar, diff = are_two_tensors_similar(
prev_res,
first_hidden_states_residual,
threshold=threshold,
parallelized=parallelized,
)
return is_similar, diff
def check_and_apply_cache(
*,
first_residual: torch.Tensor,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
threshold: float,
parallelized: bool,
mode: str,
verbose: bool,
call_remaining_fn,
remaining_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
"""
Check and apply cache based on residual similarity.
This function determines whether cached results can be used by comparing the
first block residuals. If the cache is valid, it applies cached computations;
otherwise, it computes new values and updates the cache.
Parameters
----------
first_residual : torch.Tensor
First block residual for similarity comparison.
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for "multi" mode).
threshold : float
Similarity threshold for cache validity.
parallelized : bool
Whether computation is parallelized.
mode : {"multi", "single"}
Caching mode.
verbose : bool
Whether to print caching status messages.
call_remaining_fn : callable
Function to call remaining transformer blocks.
remaining_kwargs : dict
Additional keyword arguments for `call_remaining_fn`.
Returns
-------
tuple
(updated_hidden_states, updated_encoder_hidden_states, threshold)
- updated_hidden_states (torch.Tensor)
- updated_encoder_hidden_states (torch.Tensor or None)
- threshold (float)
"""
can_use_cache, diff = get_can_use_cache(
first_residual,
threshold=threshold,
parallelized=parallelized,
mode=mode,
)
torch._dynamo.graph_break()
if can_use_cache:
if verbose:
print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, " f"new threshold={threshold:.4f}")
out = apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states, mode=mode)
updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
return updated_h, updated_enc, threshold
old_threshold = threshold
if verbose:
print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, " f"was={old_threshold:.4f} => now={threshold:.4f}")
if mode == "multi":
set_buffer("first_multi_hidden_states_residual", first_residual)
else:
set_buffer("first_single_hidden_states_residual", first_residual)
result = call_remaining_fn(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, **remaining_kwargs
)
if mode == "multi":
updated_h, updated_enc, hs_res, enc_res = result
set_buffer("multi_hidden_states_residual", hs_res)
set_buffer("multi_encoder_hidden_states_residual", enc_res)
return updated_h, updated_enc, threshold
elif mode == "single":
updated_cat_states, cat_res = result
set_buffer("single_hidden_states_residual", cat_res)
return updated_cat_states, None, threshold
raise ValueError(f"Unknown mode {mode}")
class SanaCachedTransformerBlocks(nn.Module):
"""
Caching wrapper for SANA transformer blocks.
"""
Caching utilities for V2 transformer models.
Implements first-block caching to accelerate transformer inference by reusing computations
when input changes are minimal. Supports Flux V2 architecture with double FB cache.
**Main Functions**
- :func:`cached_forward_v2` : Cached forward pass for V2 transformers.
- :func:`run_remaining_blocks_v2` : Process all remaining blocks (multi and single).
- :func:`run_remaining_multi_blocks_v2` : Process the remaining multi (double-stream) blocks only.
- :func:`run_remaining_single_blocks_v2` : Process the remaining single (single-stream) blocks only.
**Caching Strategy**
1. Compute the first transformer block.
2. Compare the residual with the cached residual.
3. If similar, reuse cached results for the remaining blocks; otherwise, recompute and update cache.
4. For double FB cache, repeat the process for single blocks.
.. note::
V2 implementation with standalone functions for improved modularity.
"""
from typing import Any, Dict, Optional, Union
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from nunchaku.caching.fbcache import check_and_apply_cache
from nunchaku.models.embeddings import pack_rotemb
from nunchaku.models.transformers.utils import pad_tensor
def cached_forward_v2(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Cached forward function for V2 transformer with first-block caching.
Replaces the transformer's forward method to enable caching optimizations.
If residual_diff_threshold_multi < 0, caching is disabled.
"""
# If caching disabled, use original forward
if self.residual_diff_threshold_multi < 0.0:
return self._original_forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projections,
timestep=timestep,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
return_dict=return_dict,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)
# Prepare inputs
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
# [1, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb
rotary_emb_txt = pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
original_hidden_states = hidden_states
# Process first block to get residual
first_block = self.transformer_blocks[0]
first_encoder_hidden_states, first_hidden_states = first_block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
joint_attention_kwargs=joint_attention_kwargs,
)
# Calculate residual for cache comparison
hidden_states = first_hidden_states
encoder_hidden_states = first_encoder_hidden_states
first_hidden_states_residual_multi = hidden_states - original_hidden_states
del original_hidden_states
# Setup remaining blocks function and apply caching
remaining_kwargs = {
"temb": temb,
"rotary_emb_img": rotary_emb_img,
"rotary_emb_txt": rotary_emb_txt,
"rotary_emb_single": rotary_emb_single,
"joint_attention_kwargs": joint_attention_kwargs,
"txt_tokens": txt_tokens,
}
if self.use_double_fb_cache:
call_remaining_fn = run_remaining_multi_blocks_v2
else:
call_remaining_fn = run_remaining_blocks_v2
hidden_states, encoder_hidden_states, _ = check_and_apply_cache(
first_residual=first_hidden_states_residual_multi,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
threshold=self.residual_diff_threshold_multi,
parallelized=False,
mode="multi",
verbose=self.verbose if hasattr(self, "verbose") else False,
call_remaining_fn=lambda hidden_states, encoder_hidden_states, **kw: call_remaining_fn(
self, hidden_states, encoder_hidden_states, **remaining_kwargs
),
remaining_kwargs={},
)
if self.use_double_fb_cache:
# Second stage caching for single blocks
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
original_cat = hidden_states
# Process first single block
first_block = self.single_transformer_blocks[0]
hidden_states = first_block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=rotary_emb_single,
joint_attention_kwargs=joint_attention_kwargs,
)
first_hidden_states_residual_single = hidden_states - original_cat
del original_cat
call_remaining_fn = run_remaining_single_blocks_v2
original_dtype = hidden_states.dtype
original_device = hidden_states.device
hidden_states, _, _ = check_and_apply_cache(
first_residual=first_hidden_states_residual_single,
hidden_states=hidden_states,
encoder_hidden_states=None,
threshold=self.residual_diff_threshold_single,
parallelized=False,
mode="single",
verbose=self.verbose,
call_remaining_fn=lambda hidden_states, encoder_hidden_states, **kw: call_remaining_fn(
self, hidden_states, encoder_hidden_states, **remaining_kwargs
),
remaining_kwargs=remaining_kwargs,
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
hidden_states = hidden_states[:, txt_tokens:, ...]
hidden_states = hidden_states.to(original_dtype).to(original_device)
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def run_remaining_blocks_v2(
self,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
joint_attention_kwargs,
txt_tokens,
**kwargs,
):
"""
Process remaining transformer blocks (both multi and single).
Called when cache is invalid. Processes all blocks after the first one.
"""
original_dtype = hidden_states.dtype
original_device = hidden_states.device
original_h = hidden_states
original_enc = encoder_hidden_states
# Process remaining multi blocks
for block in self.transformer_blocks[1:]:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
joint_attention_kwargs=joint_attention_kwargs,
)
# Concatenate encoder and decoder for single blocks
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# Process all single blocks
for block in self.single_transformer_blocks:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=rotary_emb_single,
joint_attention_kwargs=joint_attention_kwargs,
)
# Restore original dtype and device
hidden_states = hidden_states.to(original_dtype).to(original_device)
# Split concatenated result
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
hidden_states = hidden_states[:, txt_tokens:, ...]
# Ensure contiguous memory layout
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
# Calculate residuals
hs_residual = hidden_states - original_h
enc_residual = encoder_hidden_states - original_enc
return hidden_states, encoder_hidden_states, hs_residual, enc_residual
def run_remaining_multi_blocks_v2(
self,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
joint_attention_kwargs,
txt_tokens,
**kwargs,
):
"""
Process the remaining multi (double-stream) transformer blocks only.
Used when double FB cache is enabled. Skips single blocks.
"""
original_h = hidden_states
original_enc = encoder_hidden_states
# Process remaining multi blocks
for block in self.transformer_blocks[1:]:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=(rotary_emb_img, rotary_emb_txt),
joint_attention_kwargs=joint_attention_kwargs,
)
# Ensure contiguous memory layout
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
# Calculate residuals
hs_residual = hidden_states - original_h
enc_residual = encoder_hidden_states - original_enc
return hidden_states, encoder_hidden_states, hs_residual, enc_residual
def run_remaining_single_blocks_v2(
self,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
joint_attention_kwargs,
txt_tokens,
**kwargs,
):
"""
Process the remaining single (single-stream) transformer blocks.
Used for the second stage of double FB cache.
"""
# Save original for residual calculation
original_hidden_states = hidden_states.clone()
# Process remaining single blocks (skip first)
for block in self.single_transformer_blocks[1:]:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=rotary_emb_single,
joint_attention_kwargs=joint_attention_kwargs,
)
hidden_states = hidden_states.contiguous()
hs_residual = hidden_states - original_hidden_states
return hidden_states, hs_residual
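As a reminder of the layout these helpers work with: the single-block stage runs on the text and image tokens concatenated along the sequence dimension, and the final slice keeps only the image part, which is why txt_tokens is threaded through everywhere. A toy sketch of that layout (dimensions are made up, not the real Flux sizes):
import torch

txt_tokens, img_tokens, dim = 512, 4096, 64
encoder_hidden_states = torch.randn(1, txt_tokens, dim)  # text stream
hidden_states = torch.randn(1, img_tokens, dim)          # image stream

# Layout used by the single blocks: [text tokens | image tokens]
cat_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert cat_states.shape == (1, txt_tokens + img_tokens, dim)

# cached_forward_v2 keeps only the image part before norm_out / proj_out.
image_part = cat_states[:, txt_tokens:, ...]
assert image_part.shape == (1, img_tokens, dim)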
@@ -12,10 +12,9 @@ import unittest
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from nunchaku.caching.utils import cache_context, create_cache_context
from nunchaku.models.ip_adapter.utils import undo_all_mods_on_transformer
from ....caching.fbcache import cache_context, create_cache_context
from ...ip_adapter import utils
from ...ip_adapter.utils import undo_all_mods_on_transformer
def apply_IPA_on_transformer(
"""
Test for V2 Flux double FB cache implementation.
Tests the NunchakuFluxTransformer2DModelV2 with double FB cache enabled.
"""
import gc
import os
import sys
import pytest
import torch
if __name__ == "__main__":
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from diffusers import FluxPipeline
from nunchaku.caching.diffusers_adapters.flux_v2 import apply_cache_on_pipe
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,expected_lpips",
[
(True, 0.09, 0.20, 1024, 1024, 30, 0.24 if get_precision() == "int4" else 0.165),
(True, 0.09, 0.15, 1024, 1024, 50, 0.24 if get_precision() == "int4" else 0.161),
],
)
def test_flux_dev_double_fb_cache_v2(
use_double_fb_cache: bool,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
height: int,
width: int,
num_inference_steps: int,
expected_lpips: float,
):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
)
prompt = "A cat holding a sign that says hello world"
generator = torch.Generator("cuda").manual_seed(42)
image = pipeline(
prompt,
num_inference_steps=num_inference_steps,
height=height,
width=width,
guidance_scale=3.5,
generator=generator,
).images[0]
assert image is not None
assert image.size == (width, height)
del pipeline, transformer
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_v2_cache_verbose_logging():
"""Test V2 cache with verbose logging enabled."""
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
transformer.verbose = True
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
prompt = "A simple test image"
generator = torch.Generator("cuda").manual_seed(42)
image = pipeline(
prompt, num_inference_steps=5, height=512, width=512, guidance_scale=3.5, generator=generator
).images[0]
assert image is not None
assert image.size == (512, 512)
# Clean up
del pipeline, transformer
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"threshold_single",
[0.08, 0.10, 0.12, 0.15, 0.20], # Test different thresholds
)
def test_v2_threshold_variations(threshold_single: float):
"""Test V2 with different threshold_single values."""
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=threshold_single,
)
prompt = "A beautiful landscape"
generator = torch.Generator("cuda").manual_seed(42)
image = pipeline(
prompt, num_inference_steps=10, height=512, width=512, guidance_scale=3.5, generator=generator
).images[0]
assert image is not None
assert image.size == (512, 512)
del pipeline, transformer
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_v2_memory_usage():
"""Test V2 memory usage with cache enabled."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
precision = get_precision()
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.15,
)
# Measure memory before generation
torch.cuda.synchronize()
memory_before = torch.cuda.memory_allocated() / (1024**3) # GB
# Generate image
prompt = "Memory test image"
generator = torch.Generator("cuda").manual_seed(42)
image = pipeline(
prompt, num_inference_steps=20, height=1024, width=1024, guidance_scale=3.5, generator=generator
).images[0]
# Measure peak memory
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated() / (1024**3) # GB
print(f"Memory before: {memory_before:.2f} GB")
print(f"Peak memory: {peak_memory:.2f} GB")
print(f"Memory increase: {peak_memory - memory_before:.2f} GB")
# V2 typically uses ~18GB based on our tests
assert peak_memory < 20, f"Peak memory {peak_memory:.2f} GB exceeds expected limit"
assert image is not None
assert image.size == (1024, 1024)
# Clean up
del pipeline, transformer
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
print("Running V2 double FB cache tests...")
test_flux_dev_double_fb_cache_v2(
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
height=512,
width=512,
num_inference_steps=10,
expected_lpips=0.24 if get_precision() == "int4" else 0.165,
)
print("✓ Basic test passed")
test_v2_cache_verbose_logging()
print("✓ Verbose logging test passed")
test_v2_memory_usage()
print("✓ Memory usage test passed")
print("\nAll V2 double FB cache tests passed!")