Commit b3f12860 authored by Bluear7878

[Auto Sync] feat: double FB cache + adaptive mechanisms (#76)

* DoubleFBCache

* rename > DoubleFBCache to use_double_fb_cache
parent e4f8ae9b
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

# Auto-detect the SVDQuant precision to load (e.g. "int4").
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{precision}-flux.1-dev"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable Double FB Cache: first-block caching on both the joint (multi) and the
# single transformer blocks, each with its own residual-difference threshold.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

image = pipeline(
    ["A cat holding a sign that says hello world"],
    num_inference_steps=50,
).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
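The pre-existing single-threshold FB cache remains available by leaving use_double_fb_cache off. A minimal sketch, assuming apply_cache_on_pipe forwards its keyword arguments to the apply_cache_on_transformer adapter shown in the diff below:

apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=False,
    residual_diff_threshold=0.12,  # one threshold drives the whole first-block cache, as before
)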
@@ -7,16 +7,28 @@ from torch import nn
from ...caching import utils


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    use_double_fb_cache: bool = False,
    residual_diff_threshold: float = 0.12,
    residual_diff_threshold_multi: float | None = None,
    residual_diff_threshold_single: float = 0.1,
):
    if residual_diff_threshold_multi is None:
        residual_diff_threshold_multi = residual_diff_threshold

    if getattr(transformer, "_is_cached", False):
        transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
            use_double_fb_cache,
            residual_diff_threshold_multi,
            residual_diff_threshold_single,
        )
        return transformer

    cached_transformer_blocks = nn.ModuleList(
        [
            utils.FluxCachedTransformerBlocks(
                transformer=transformer,
                use_double_fb_cache=use_double_fb_cache,
                residual_diff_threshold_multi=residual_diff_threshold_multi,
                residual_diff_threshold_single=residual_diff_threshold_single,
                return_hidden_states_first=False,
            )
        ]
    )
...
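For callers that hold the quantized transformer rather than the full pipeline, the adapter above can also be applied directly. A minimal sketch; the import path of the adapter module is an assumption, not taken from this diff:

from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_transformer  # path assumed

transformer = apply_cache_on_transformer(
    transformer,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)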
@@ -3,11 +3,15 @@
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional, Tuple

import torch
from torch import nn

from nunchaku.models.transformers.utils import pad_tensor

num_transformer_blocks = 19  # FIXME
num_single_transformer_blocks = 38  # FIXME


@dataclasses.dataclass
class CacheContext:
@@ -75,38 +79,127 @@ def cache_context(cache_context):
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = (mean_diff / mean_t1).item()
    return diff < threshold, diff
@torch.compiler.disable
def apply_prev_hidden_states_residual(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
    if mode == "multi":
        hidden_states_residual = get_buffer("multi_hidden_states_residual")
        assert hidden_states_residual is not None, (
            "multi_hidden_states_residual must be set before"
        )
        hidden_states = hidden_states + hidden_states_residual
        hidden_states = hidden_states.contiguous()

        if encoder_hidden_states is not None:
            enc_hidden_res = get_buffer("multi_encoder_hidden_states_residual")
            msg = "multi_encoder_hidden_states_residual must be set before"
            assert enc_hidden_res is not None, msg
            encoder_hidden_states = encoder_hidden_states + enc_hidden_res
            encoder_hidden_states = encoder_hidden_states.contiguous()

        return hidden_states, encoder_hidden_states

    elif mode == "single":
        single_residual = get_buffer("single_hidden_states_residual")
        msg = "single_hidden_states_residual must be set before"
        assert single_residual is not None, msg
        hidden_states = hidden_states + single_residual
        hidden_states = hidden_states.contiguous()
        return hidden_states

    else:
        raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")
@torch.compiler.disable
def get_can_use_cache(
    first_hidden_states_residual: torch.Tensor,
    threshold: float,
    parallelized: bool = False,
    mode: str = "multi",
):
    if mode == "multi":
        buffer_name = "first_multi_hidden_states_residual"
    elif mode == "single":
        buffer_name = "first_single_hidden_states_residual"
    else:
        raise ValueError(f"Unknown mode {mode}; expected 'multi' or 'single'")

    prev_res = get_buffer(buffer_name)
    if prev_res is None:
        return False, threshold

    is_similar, diff = are_two_tensors_similar(
        prev_res,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return is_similar, diff
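The cache decision is a relative mean absolute difference between the current first-block residual and the residual stored on the previous step. A self-contained sketch of that metric with toy shapes (not the real FLUX dimensions):

import torch

def relative_l1_diff(prev: torch.Tensor, cur: torch.Tensor) -> float:
    # Same quantity as are_two_tensors_similar: mean(|prev - cur|) / mean(|prev|)
    return ((prev - cur).abs().mean() / prev.abs().mean()).item()

prev_res = torch.randn(1, 8, 16)
cur_res = prev_res + 0.05 * torch.randn_like(prev_res)
print(relative_l1_diff(prev_res, cur_res) < 0.12)  # small perturbation -> below threshold -> cacheable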
def check_and_apply_cache(
    *,
    first_residual: torch.Tensor,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    threshold: float,
    parallelized: bool,
    mode: str,
    verbose: bool,
    call_remaining_fn,
    remaining_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
    can_use_cache, diff = get_can_use_cache(
        first_residual,
        threshold=threshold,
        parallelized=parallelized,
        mode=mode,
    )

    torch._dynamo.graph_break()

    if can_use_cache:
        if verbose:
            print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, "
                  f"new threshold={threshold:.4f}")
        out = apply_prev_hidden_states_residual(
            hidden_states, encoder_hidden_states, mode=mode
        )
        updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
        return updated_h, updated_enc, threshold

    old_threshold = threshold
    if verbose:
        print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, "
              f"was={old_threshold:.4f} => now={threshold:.4f}")

    if mode == "multi":
        set_buffer("first_multi_hidden_states_residual", first_residual)
    else:
        set_buffer("first_single_hidden_states_residual", first_residual)

    result = call_remaining_fn(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        **remaining_kwargs,
    )

    if mode == "multi":
        updated_h, updated_enc, hs_res, enc_res = result
        set_buffer("multi_hidden_states_residual", hs_res)
        set_buffer("multi_encoder_hidden_states_residual", enc_res)
        return updated_h, updated_enc, threshold
    elif mode == "single":
        updated_cat_states, cat_res = result
        set_buffer("single_hidden_states_residual", cat_res)
        return updated_cat_states, None, threshold

    raise ValueError(f"Unknown mode {mode}")
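Across denoising steps this yields the usual first-block-cache pattern: the remaining blocks are recomputed whenever the first block's residual drifts past the threshold and reused otherwise. A toy illustration with made-up scalar residuals (hypothetical numbers, not measured):

import torch

threshold = 0.09
prev = None
for step, scale in enumerate([1.00, 1.02, 1.03, 1.40]):  # pretend first-block residual magnitudes
    cur = scale * torch.ones(4)
    hit = prev is not None and ((prev - cur).abs().mean() / prev.abs().mean()).item() < threshold
    print(f"step {step}:", "reuse cached residuals" if hit else "run remaining blocks")
    prev = cur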
class SanaCachedTransformerBlocks(nn.Module):
@@ -230,109 +323,339 @@ class FluxCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        transformer: nn.Module = None,
        use_double_fb_cache: bool = True,
        residual_diff_threshold_multi: float,
        residual_diff_threshold_single: float,
        return_hidden_states_first: bool = True,
        return_hidden_states_only: bool = False,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.single_transformer_blocks = transformer.single_transformer_blocks
        self.use_double_fb_cache = use_double_fb_cache
        self.residual_diff_threshold_multi = residual_diff_threshold_multi
        self.residual_diff_threshold_single = residual_diff_threshold_single
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.verbose = verbose
        self.m = self.transformer_blocks[0].m
        self.dtype = torch.bfloat16 if self.m.isBF16() else torch.float16
        self.device = transformer.device

    @staticmethod
    def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
        assert rotemb.dtype == torch.float32
        B = rotemb.shape[0]
        M = rotemb.shape[1]
        D = rotemb.shape[2] * 2
        msg_shape = "rotemb shape must be (B, M, D//2, 1, 2)"
        assert rotemb.shape == (B, M, D // 2, 1, 2), msg_shape
        assert M % 16 == 0
        assert D % 8 == 0
        rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
        rotemb = rotemb.permute(0, 1, 3, 2, 4)
        # 16*8 pack, FP32 accumulator (C) format
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
        rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
        rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
        rotemb = rotemb.contiguous()
        rotemb = rotemb.view(B, M, D)
        return rotemb

    def update_residual_diff_threshold(
        self,
        use_double_fb_cache=True,
        residual_diff_threshold_multi=0.12,
        residual_diff_threshold_single=0.09,
    ):
        self.use_double_fb_cache = use_double_fb_cache
        self.residual_diff_threshold_multi = residual_diff_threshold_multi
        self.residual_diff_threshold_single = residual_diff_threshold_single
    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
    ):
        batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        hidden_states = hidden_states.to(self.dtype).to(self.device)
        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        if controlnet_block_samples is not None:
            controlnet_block_samples = (
                torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
            )
        if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
            controlnet_single_block_samples = (
                torch.stack(controlnet_single_block_samples).to(self.device)
                if len(controlnet_single_block_samples) > 0
                else None
            )

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
        # [1, tokens, head_dim/2, 1, 2] (sincos)
        total_tokens = txt_tokens + img_tokens
        assert image_rotary_emb.shape[2] == 1 * total_tokens
        image_rotary_emb = image_rotary_emb.reshape(
            [1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
        )
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]
        rotary_emb_single = image_rotary_emb

        rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))

        if (self.residual_diff_threshold_multi < 0.0) or (batch_size > 1):
            if batch_size > 1 and self.verbose:
                print("Batch size > 1 currently not supported")

            hidden_states = self.m.forward(
                hidden_states,
                encoder_hidden_states,
                temb,
                rotary_emb_img,
                rotary_emb_txt,
                rotary_emb_single,
                controlnet_block_samples,
                controlnet_single_block_samples,
                skip_first_layer,
            )

            hidden_states = hidden_states.to(original_dtype).to(original_device)
            encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
            hidden_states = hidden_states[:, txt_tokens:, ...]

            if self.return_hidden_states_only:
                return hidden_states
            if self.return_hidden_states_first:
                return hidden_states, encoder_hidden_states
            return encoder_hidden_states, hidden_states

        remaining_kwargs = {
            "temb": temb,
            "rotary_emb_img": rotary_emb_img,
            "rotary_emb_txt": rotary_emb_txt,
            "rotary_emb_single": rotary_emb_single,
            "controlnet_block_samples": controlnet_block_samples,
            "controlnet_single_block_samples": controlnet_single_block_samples,
            "txt_tokens": txt_tokens,
        }

        original_hidden_states = hidden_states
        first_hidden_states, first_encoder_hidden_states = self.m.forward_layer(
            0,
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt,
            controlnet_block_samples,
            controlnet_single_block_samples,
        )
        hidden_states = first_hidden_states
        encoder_hidden_states = first_encoder_hidden_states
        first_hidden_states_residual_multi = hidden_states - original_hidden_states
        del original_hidden_states

        if self.use_double_fb_cache:
            call_remaining_fn = self.call_remaining_multi_transformer_blocks
        else:
            call_remaining_fn = self.call_remaining_FBCache_transformer_blocks

        torch._dynamo.graph_break()

        updated_h, updated_enc, threshold = check_and_apply_cache(
            first_residual=first_hidden_states_residual_multi,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            threshold=self.residual_diff_threshold_multi,
            parallelized=(
                self.transformer is not None
                and getattr(self.transformer, "_is_parallelized", False)
            ),
            mode="multi",
            verbose=self.verbose,
            call_remaining_fn=call_remaining_fn,
            remaining_kwargs=remaining_kwargs,
        )
        self.residual_diff_threshold_multi = threshold

        if not self.use_double_fb_cache:
            if self.return_hidden_states_only:
                return updated_h
            if self.return_hidden_states_first:
                return updated_h, updated_enc
            return updated_enc, updated_h

        # DoubleFBCache
        cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
        original_cat = cat_hidden_states

        cat_hidden_states = self.m.forward_single_layer(
            0, cat_hidden_states, temb, rotary_emb_single
        )

        first_hidden_states_residual_single = cat_hidden_states - original_cat
        del original_cat

        call_remaining_fn_single = self.call_remaining_single_transformer_blocks

        updated_cat, _, threshold = check_and_apply_cache(
            first_residual=first_hidden_states_residual_single,
            hidden_states=cat_hidden_states,
            encoder_hidden_states=None,
            threshold=self.residual_diff_threshold_single,
            parallelized=(
                self.transformer is not None
                and getattr(self.transformer, "_is_parallelized", False)
            ),
            mode="single",
            verbose=self.verbose,
            call_remaining_fn=call_remaining_fn_single,
            remaining_kwargs=remaining_kwargs,
        )
        self.residual_diff_threshold_single = threshold
        # torch._dynamo.graph_break()

        final_enc = updated_cat[:, :txt_tokens, ...]
        final_h = updated_cat[:, txt_tokens:, ...]

        final_h = final_h.to(original_dtype).to(original_device)
        final_enc = final_enc.to(original_dtype).to(original_device)

        if self.return_hidden_states_only:
            return final_h
        if self.return_hidden_states_first:
            return final_h, final_enc
        return final_enc, final_h
    def call_remaining_FBCache_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        rotary_emb_img: torch.Tensor,
        rotary_emb_txt: torch.Tensor,
        rotary_emb_single: torch.Tensor,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=True,
        txt_tokens=None,
    ):
        original_dtype = hidden_states.dtype
        original_device = hidden_states.device

        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states

        hidden_states = self.m.forward(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt,
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
            skip_first_layer,
        )

        hidden_states = hidden_states.to(original_dtype).to(original_device)
        encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
        hidden_states = hidden_states[:, txt_tokens:, ...]

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        enc_residual = encoder_hidden_states - original_encoder_hidden_states

        return hidden_states, encoder_hidden_states, hidden_states_residual, enc_residual
    def call_remaining_multi_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        rotary_emb_img: torch.Tensor,
        rotary_emb_txt: torch.Tensor,
        rotary_emb_single: torch.Tensor,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
        txt_tokens=None,
    ):
        start_idx = 1
        original_hidden_states = hidden_states.clone()
        original_encoder_hidden_states = encoder_hidden_states.clone()

        for idx in range(start_idx, num_transformer_blocks):
            hidden_states, encoder_hidden_states = self.m.forward_layer(
                idx,
                hidden_states,
                encoder_hidden_states,
                temb,
                rotary_emb_img,
                rotary_emb_txt,
                controlnet_block_samples,
                controlnet_single_block_samples,
            )

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        hs_res = hidden_states - original_hidden_states
        enc_res = encoder_hidden_states - original_encoder_hidden_states
        return hidden_states, encoder_hidden_states, hs_res, enc_res

    def call_remaining_single_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        rotary_emb_img: torch.Tensor,
        rotary_emb_txt: torch.Tensor,
        rotary_emb_single: torch.Tensor,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
        txt_tokens=None,
    ):
        start_idx = 1
        original_hidden_states = hidden_states.clone()

        for idx in range(start_idx, num_single_transformer_blocks):
            hidden_states = self.m.forward_single_layer(
                idx,
                hidden_states,
                temb,
                rotary_emb_single,
            )

        hidden_states = hidden_states.contiguous()
        hs_res = hidden_states - original_hidden_states
        return hidden_states, hs_res
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
        (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
        (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
    ],
)
def test_flux_dev_cache(
use_double_fb_cache: bool,
    residual_diff_threshold_multi: float,
    residual_diff_threshold_single: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
expected_lpips=expected_lpips,
)
@@ -13,6 +13,7 @@ from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
@@ -141,6 +142,9 @@ def run_test(
    attention_impl: str = "flashattn2",  # "flashattn2" or "nunchaku-fp16"
    cpu_offload: bool = False,
    cache_threshold: float = 0,
    use_double_fb_cache: bool = False,
    residual_diff_threshold_multi: float = 0,
    residual_diff_threshold_single: float = 0,
    lora_names: str | list[str] | None = None,
    lora_strengths: float | list[float] = 1.0,
    max_dataset_size: int = 4,
@@ -259,6 +263,12 @@ def run_test(
        precision_str += "-co"
    if cache_threshold > 0:
        precision_str += f"-cache{cache_threshold}"
    if use_double_fb_cache:
        precision_str += "-dfb"
    if residual_diff_threshold_multi > 0:
        precision_str += f"-rdm{residual_diff_threshold_multi}"
    if residual_diff_threshold_single > 0:
        precision_str += f"-rds{residual_diff_threshold_single}"
    if i2f_mode is not None:
        precision_str += f"-i2f{i2f_mode}"
    if batch_size > 1:
@@ -303,6 +313,14 @@ def run_test(
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to("cuda")

        if use_double_fb_cache:
            apply_cache_on_pipe(
                pipeline,
                use_double_fb_cache=use_double_fb_cache,
                residual_diff_threshold_multi=residual_diff_threshold_multi,
                residual_diff_threshold_single=residual_diff_threshold_single,
            )

        run_pipeline(
            batch_size=batch_size,
            dataset=dataset,
...