Commit b3f12860 authored by Bluear7878

[Auto Sync] feat: double FB cache + adaptive mechanisms (#76)

* DoubleFBCache

* rename: DoubleFBCache to use_double_fb_cache
parent e4f8ae9b
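
For context: a first-block cache (FBCache) runs only the first transformer block at each denoising step, compares that block's output residual against the residual cached from the previous step, and, when the relative change falls below a threshold, skips the remaining blocks and reuses the cached output. With use_double_fb_cache, the multi (double-stream) blocks and the single-stream blocks of FLUX each get this test with their own threshold (residual_diff_threshold_multi and residual_diff_threshold_single). A minimal sketch of the skip decision, with hypothetical names, not the nunchaku implementation:

import torch

def relative_l1(curr: torch.Tensor, prev: torch.Tensor) -> float:
    # Mean relative L1 change between this step's first-block residual
    # and the residual cached from the previous step.
    return ((curr - prev).abs().mean() / prev.abs().mean()).item()

def should_skip(curr_residual: torch.Tensor, cached_residual: torch.Tensor | None, threshold: float) -> bool:
    # Skip the remaining blocks only when the first block barely changed.
    if cached_residual is None or threshold <= 0:
        return False
    return relative_l1(curr_residual, cached_residual) < threshold

The new example script enables the cache on a FLUX.1-dev pipeline: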
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

# Pick the quantized checkpoint that matches this GPU (e.g. "int4" or "fp4").
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{precision}-flux.1-dev"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable double first-block caching: one threshold for the multi
# (double-stream) blocks and one for the single-stream blocks.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

image = pipeline(
    ["A cat holding a sign that says hello world"],
    num_inference_steps=50,
).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
@@ -7,16 +7,28 @@ from torch import nn

 from ...caching import utils

-def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    use_double_fb_cache: bool = False,
+    residual_diff_threshold: float = 0.12,
+    residual_diff_threshold_multi: float | None = None,
+    residual_diff_threshold_single: float = 0.1,
+):
+    if residual_diff_threshold_multi is None:
+        residual_diff_threshold_multi = residual_diff_threshold
+
     if getattr(transformer, "_is_cached", False):
-        transformer.cached_transformer_blocks[0].update_threshold(residual_diff_threshold)
+        transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
+            use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
+        )
         return transformer

     cached_transformer_blocks = nn.ModuleList(
         [
             utils.FluxCachedTransformerBlocks(
                 transformer=transformer,
-                residual_diff_threshold=residual_diff_threshold,
+                use_double_fb_cache=use_double_fb_cache,
+                residual_diff_threshold_multi=residual_diff_threshold_multi,
+                residual_diff_threshold_single=residual_diff_threshold_single,
                 return_hidden_states_first=False,
             )
         ]
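
Note that the old update_threshold path could only adjust a single threshold; update_residual_diff_threshold updates the cache mode and both thresholds at once, so re-applying the cache to a live pipeline picks up every setting.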
This diff is collapsed.
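
The accompanying regression test exercises the new options end to end: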
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
        (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
        (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
    ],
)
def test_flux_dev_cache(
    use_double_fb_cache: bool,
    residual_diff_threshold_multi: float,
    residual_diff_threshold_single: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: str | None,
    lora_strength: float,
    expected_lpips: float,
):
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        use_double_fb_cache=use_double_fb_cache,
        residual_diff_threshold_multi=residual_diff_threshold_multi,
        residual_diff_threshold_single=residual_diff_threshold_single,
        expected_lpips=expected_lpips,
    )
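
In the shared test harness, run_test gains matching parameters, folds the cache settings into the result tag, and applies the cache before generation: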
@@ -13,6 +13,7 @@ from tqdm import tqdm

 import nunchaku
 from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
 from nunchaku.lora.flux.compose import compose_lora
+from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

 from ..data import get_dataset
 from ..utils import already_generate, compute_lpips, hash_str_to_int
@@ -141,6 +142,9 @@ def run_test(
     attention_impl: str = "flashattn2",  # "flashattn2" or "nunchaku-fp16"
     cpu_offload: bool = False,
     cache_threshold: float = 0,
+    use_double_fb_cache: bool = False,
+    residual_diff_threshold_multi: float = 0,
+    residual_diff_threshold_single: float = 0,
     lora_names: str | list[str] | None = None,
     lora_strengths: float | list[float] = 1.0,
     max_dataset_size: int = 4,
@@ -259,6 +263,12 @@ def run_test(
         precision_str += "-co"
     if cache_threshold > 0:
         precision_str += f"-cache{cache_threshold}"
+    if use_double_fb_cache:
+        precision_str += "-dfb"
+    if residual_diff_threshold_multi > 0:
+        precision_str += f"-rdm{residual_diff_threshold_multi}"
+    if residual_diff_threshold_single > 0:
+        precision_str += f"-rds{residual_diff_threshold_single}"
     if i2f_mode is not None:
         precision_str += f"-i2f{i2f_mode}"
     if batch_size > 1:
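
With the settings used in the tests, an int4 run would be tagged something like int4-dfb-rdm0.09-rds0.12 (illustrative; the other flags add their own suffixes), so results from different cache configurations are kept apart.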
@@ -303,6 +313,14 @@ def run_test(
         pipeline.enable_sequential_cpu_offload()
     else:
         pipeline = pipeline.to("cuda")

+    if use_double_fb_cache:
+        apply_cache_on_pipe(
+            pipeline,
+            use_double_fb_cache=use_double_fb_cache,
+            residual_diff_threshold_multi=residual_diff_threshold_multi,
+            residual_diff_threshold_single=residual_diff_threshold_single,
+        )
+
     run_pipeline(
         batch_size=batch_size,
         dataset=dataset,