Unverified commit 17ddd2d9 authored by Muyang Li, committed by GitHub
Browse files

fix: remove duplicated 'cache_dir' argument when initializing Nunchaku models (#330)

* style: reformat some code

* update

* remove the dev-scripts

* fix the lpips to pass the tests

* style: use black to reformat some code

* fix the lpips number
parent 91ad229f
......@@ -14,12 +14,14 @@ def apply_cache_on_transformer(
residual_diff_threshold: float = 0.12,
residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float = 0.1,
):
):
if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold
if getattr(transformer, "_is_cached", False):
transformer.cached_transformer_blocks[0].update_residual_diff_threshold(use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single)
transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
)
return transformer
cached_transformer_blocks = nn.ModuleList(
......
......@@ -4,6 +4,7 @@ import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional, Tuple
import torch
from torch import nn
......@@ -85,13 +86,13 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, mode: str = "multi",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, (
"multi_hidden_states_residual must be set before"
)
assert hidden_states_residual is not None, "multi_hidden_states_residual must be set before"
hidden_states = hidden_states + hidden_states_residual
hidden_states = hidden_states.contiguous()
......@@ -118,7 +119,9 @@ def apply_prev_hidden_states_residual(
@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"):
def get_can_use_cache(
first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"
):
if mode == "multi":
buffer_name = "first_multi_hidden_states_residual"
elif mode == "single":
......@@ -162,20 +165,16 @@ def check_and_apply_cache(
if can_use_cache:
if verbose:
print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, "
f"new threshold={threshold:.4f}")
print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, " f"new threshold={threshold:.4f}")
out = apply_prev_hidden_states_residual(
hidden_states, encoder_hidden_states, mode=mode
)
out = apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states, mode=mode)
updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
return updated_h, updated_enc, threshold
old_threshold = threshold
if verbose:
print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, "
f"was={old_threshold:.4f} => now={threshold:.4f}")
print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, " f"was={old_threshold:.4f} => now={threshold:.4f}")
if mode == "multi":
set_buffer("first_multi_hidden_states_residual", first_residual)
......@@ -183,9 +182,7 @@ def check_and_apply_cache(
set_buffer("first_single_hidden_states_residual", first_residual)
result = call_remaining_fn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
**remaining_kwargs
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, **remaining_kwargs
)
if mode == "multi":
......@@ -369,10 +366,7 @@ class FluxCachedTransformerBlocks(nn.Module):
return rotemb
def update_residual_diff_threshold(
self,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.12,
residual_diff_threshold_single=0.09
self, use_double_fb_cache=True, residual_diff_threshold_multi=0.12, residual_diff_threshold_single=0.09
):
self.use_double_fb_cache = use_double_fb_cache
self.residual_diff_threshold_multi = residual_diff_threshold_multi
......@@ -420,9 +414,7 @@ class FluxCachedTransformerBlocks(nn.Module):
total_tokens = txt_tokens + img_tokens
assert image_rotary_emb.shape[2] == 1 * total_tokens
image_rotary_emb = image_rotary_emb.reshape(
[1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]
rotary_emb_single = image_rotary_emb
......@@ -495,10 +487,7 @@ class FluxCachedTransformerBlocks(nn.Module):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
threshold=self.residual_diff_threshold_multi,
parallelized=(
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
mode="multi",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn,
......@@ -515,9 +504,7 @@ class FluxCachedTransformerBlocks(nn.Module):
# DoubleFBCache
cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
original_cat = cat_hidden_states
cat_hidden_states = self.m.forward_single_layer(
0, cat_hidden_states, temb, rotary_emb_single
)
cat_hidden_states = self.m.forward_single_layer(0, cat_hidden_states, temb, rotary_emb_single)
first_hidden_states_residual_single = cat_hidden_states - original_cat
del original_cat
......@@ -529,10 +516,7 @@ class FluxCachedTransformerBlocks(nn.Module):
hidden_states=cat_hidden_states,
encoder_hidden_states=None,
threshold=self.residual_diff_threshold_single,
parallelized=(
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
mode="single",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn_single,
......
......@@ -47,17 +47,23 @@ class NunchakuModelLoaderMixin:
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
config, _, _ = cls.load_config(
pretrained_model_name_or_path,
subfolder=subfolder,
cache_dir=kwargs.get("cache_dir", None),
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
......
......@@ -9,12 +9,13 @@ from .utils import run_test
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),],
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
],
)
def test_flux_dev_cache(
use_double_fb_cache: bool,
residual_diff_threshold_multi : float,
residual_diff_threshold_single : float,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
height: int,
width: int,
num_inference_steps: int,
......
......@@ -9,7 +9,7 @@ from .utils import run_test
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
(2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
(2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.156),
],
)
def test_flux_dev(
......
......@@ -8,8 +8,8 @@ from .utils import run_test
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
(1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
(1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
(1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
(1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
(2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
],
......
import pytest
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
......
......@@ -12,8 +12,8 @@ from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux.compose import compose_lora
from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
......@@ -143,8 +143,8 @@ def run_test(
cpu_offload: bool = False,
cache_threshold: float = 0,
use_double_fb_cache: bool = False,
residual_diff_threshold_multi : float = 0,
residual_diff_threshold_single : float = 0,
residual_diff_threshold_multi: float = 0,
residual_diff_threshold_single: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 4,
......@@ -319,7 +319,8 @@ def run_test(
pipeline,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single)
residual_diff_threshold_single=residual_diff_threshold_single,
)
run_pipeline(
batch_size=batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment