Unverified Commit 17ddd2d9 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

fix: remove duplicated 'cache_dir' argument when initializing Nunchaku models (#330)

* style: reformat some codes

* update

* remove the dev-scripts

* fix the lpips to pass the tests

* style: using black to refract some codes

* fix the lpips number
parent 91ad229f
...@@ -14,12 +14,14 @@ def apply_cache_on_transformer( ...@@ -14,12 +14,14 @@ def apply_cache_on_transformer(
residual_diff_threshold: float = 0.12, residual_diff_threshold: float = 0.12,
residual_diff_threshold_multi: float | None = None, residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float = 0.1, residual_diff_threshold_single: float = 0.1,
): ):
if residual_diff_threshold_multi is None: if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold residual_diff_threshold_multi = residual_diff_threshold
if getattr(transformer, "_is_cached", False): if getattr(transformer, "_is_cached", False):
transformer.cached_transformer_blocks[0].update_residual_diff_threshold(use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single) transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
)
return transformer return transformer
cached_transformer_blocks = nn.ModuleList( cached_transformer_blocks = nn.ModuleList(
......
...@@ -4,6 +4,7 @@ import contextlib ...@@ -4,6 +4,7 @@ import contextlib
import dataclasses import dataclasses
from collections import defaultdict from collections import defaultdict
from typing import DefaultDict, Dict, Optional, Tuple from typing import DefaultDict, Dict, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -85,13 +86,13 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False): ...@@ -85,13 +86,13 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
@torch.compiler.disable @torch.compiler.disable
def apply_prev_hidden_states_residual( def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, mode: str = "multi", hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if mode == "multi": if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual") hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, ( assert hidden_states_residual is not None, "multi_hidden_states_residual must be set before"
"multi_hidden_states_residual must be set before"
)
hidden_states = hidden_states + hidden_states_residual hidden_states = hidden_states + hidden_states_residual
hidden_states = hidden_states.contiguous() hidden_states = hidden_states.contiguous()
...@@ -118,7 +119,9 @@ def apply_prev_hidden_states_residual( ...@@ -118,7 +119,9 @@ def apply_prev_hidden_states_residual(
@torch.compiler.disable @torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"): def get_can_use_cache(
first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"
):
if mode == "multi": if mode == "multi":
buffer_name = "first_multi_hidden_states_residual" buffer_name = "first_multi_hidden_states_residual"
elif mode == "single": elif mode == "single":
...@@ -162,20 +165,16 @@ def check_and_apply_cache( ...@@ -162,20 +165,16 @@ def check_and_apply_cache(
if can_use_cache: if can_use_cache:
if verbose: if verbose:
print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, " print(f"[{mode.upper()}] Cache hit! diff={diff:.4f}, " f"new threshold={threshold:.4f}")
f"new threshold={threshold:.4f}")
out = apply_prev_hidden_states_residual( out = apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states, mode=mode)
hidden_states, encoder_hidden_states, mode=mode
)
updated_h, updated_enc = out if isinstance(out, tuple) else (out, None) updated_h, updated_enc = out if isinstance(out, tuple) else (out, None)
return updated_h, updated_enc, threshold return updated_h, updated_enc, threshold
old_threshold = threshold old_threshold = threshold
if verbose: if verbose:
print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, " print(f"[{mode.upper()}] Cache miss. diff={diff:.4f}, " f"was={old_threshold:.4f} => now={threshold:.4f}")
f"was={old_threshold:.4f} => now={threshold:.4f}")
if mode == "multi": if mode == "multi":
set_buffer("first_multi_hidden_states_residual", first_residual) set_buffer("first_multi_hidden_states_residual", first_residual)
...@@ -183,9 +182,7 @@ def check_and_apply_cache( ...@@ -183,9 +182,7 @@ def check_and_apply_cache(
set_buffer("first_single_hidden_states_residual", first_residual) set_buffer("first_single_hidden_states_residual", first_residual)
result = call_remaining_fn( result = call_remaining_fn(
hidden_states=hidden_states, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, **remaining_kwargs
encoder_hidden_states=encoder_hidden_states,
**remaining_kwargs
) )
if mode == "multi": if mode == "multi":
...@@ -369,10 +366,7 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -369,10 +366,7 @@ class FluxCachedTransformerBlocks(nn.Module):
return rotemb return rotemb
def update_residual_diff_threshold( def update_residual_diff_threshold(
self, self, use_double_fb_cache=True, residual_diff_threshold_multi=0.12, residual_diff_threshold_single=0.09
use_double_fb_cache=True,
residual_diff_threshold_multi=0.12,
residual_diff_threshold_single=0.09
): ):
self.use_double_fb_cache = use_double_fb_cache self.use_double_fb_cache = use_double_fb_cache
self.residual_diff_threshold_multi = residual_diff_threshold_multi self.residual_diff_threshold_multi = residual_diff_threshold_multi
...@@ -420,9 +414,7 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -420,9 +414,7 @@ class FluxCachedTransformerBlocks(nn.Module):
total_tokens = txt_tokens + img_tokens total_tokens = txt_tokens + img_tokens
assert image_rotary_emb.shape[2] == 1 * total_tokens assert image_rotary_emb.shape[2] == 1 * total_tokens
image_rotary_emb = image_rotary_emb.reshape( image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
[1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]]
)
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]
rotary_emb_single = image_rotary_emb rotary_emb_single = image_rotary_emb
...@@ -495,10 +487,7 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -495,10 +487,7 @@ class FluxCachedTransformerBlocks(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
threshold=self.residual_diff_threshold_multi, threshold=self.residual_diff_threshold_multi,
parallelized=( parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
mode="multi", mode="multi",
verbose=self.verbose, verbose=self.verbose,
call_remaining_fn=call_remaining_fn, call_remaining_fn=call_remaining_fn,
...@@ -515,9 +504,7 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -515,9 +504,7 @@ class FluxCachedTransformerBlocks(nn.Module):
# DoubleFBCache # DoubleFBCache
cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1) cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
original_cat = cat_hidden_states original_cat = cat_hidden_states
cat_hidden_states = self.m.forward_single_layer( cat_hidden_states = self.m.forward_single_layer(0, cat_hidden_states, temb, rotary_emb_single)
0, cat_hidden_states, temb, rotary_emb_single
)
first_hidden_states_residual_single = cat_hidden_states - original_cat first_hidden_states_residual_single = cat_hidden_states - original_cat
del original_cat del original_cat
...@@ -529,10 +516,7 @@ class FluxCachedTransformerBlocks(nn.Module): ...@@ -529,10 +516,7 @@ class FluxCachedTransformerBlocks(nn.Module):
hidden_states=cat_hidden_states, hidden_states=cat_hidden_states,
encoder_hidden_states=None, encoder_hidden_states=None,
threshold=self.residual_diff_threshold_single, threshold=self.residual_diff_threshold_single,
parallelized=( parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
self.transformer is not None
and getattr(self.transformer, "_is_parallelized", False)
),
mode="single", mode="single",
verbose=self.verbose, verbose=self.verbose,
call_remaining_fn=call_remaining_fn_single, call_remaining_fn=call_remaining_fn_single,
......
...@@ -47,17 +47,23 @@ class NunchakuModelLoaderMixin: ...@@ -47,17 +47,23 @@ class NunchakuModelLoaderMixin:
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
) )
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
config, _, _ = cls.load_config( config, _, _ = cls.load_config(
pretrained_model_name_or_path, pretrained_model_name_or_path,
subfolder=subfolder, subfolder=subfolder,
cache_dir=kwargs.get("cache_dir", None), cache_dir=cache_dir,
return_unused_kwargs=True, return_unused_kwargs=True,
return_commit_hash=True, return_commit_hash=True,
force_download=kwargs.get("force_download", False), force_download=force_download,
proxies=kwargs.get("proxies", None), proxies=proxies,
local_files_only=kwargs.get("local_files_only", None), local_files_only=local_files_only,
token=kwargs.get("token", None), token=token,
revision=kwargs.get("revision", None), revision=revision,
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"}, user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs, **kwargs,
) )
......
...@@ -9,12 +9,13 @@ from .utils import run_test ...@@ -9,12 +9,13 @@ from .utils import run_test
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips", "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[ [
(True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144), (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),], (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
],
) )
def test_flux_dev_cache( def test_flux_dev_cache(
use_double_fb_cache: bool, use_double_fb_cache: bool,
residual_diff_threshold_multi : float, residual_diff_threshold_multi: float,
residual_diff_threshold_single : float, residual_diff_threshold_single: float,
height: int, height: int,
width: int, width: int,
num_inference_steps: int, num_inference_steps: int,
......
...@@ -9,7 +9,7 @@ from .utils import run_test ...@@ -9,7 +9,7 @@ from .utils import run_test
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips", "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[ [
(1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146), (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
(2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133), (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.156),
], ],
) )
def test_flux_dev( def test_flux_dev(
......
...@@ -8,8 +8,8 @@ from .utils import run_test ...@@ -8,8 +8,8 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", "height,width,attention_impl,cpu_offload,expected_lpips",
[ [
(1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113), (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
(1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113), (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
(1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138), (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
(2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120), (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
], ],
......
import pytest import pytest
from .utils import run_test
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
......
...@@ -12,8 +12,8 @@ from tqdm import tqdm ...@@ -12,8 +12,8 @@ from tqdm import tqdm
import nunchaku import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux.compose import compose_lora
from ..data import get_dataset from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int from ..utils import already_generate, compute_lpips, hash_str_to_int
...@@ -143,8 +143,8 @@ def run_test( ...@@ -143,8 +143,8 @@ def run_test(
cpu_offload: bool = False, cpu_offload: bool = False,
cache_threshold: float = 0, cache_threshold: float = 0,
use_double_fb_cache: bool = False, use_double_fb_cache: bool = False,
residual_diff_threshold_multi : float = 0, residual_diff_threshold_multi: float = 0,
residual_diff_threshold_single : float = 0, residual_diff_threshold_single: float = 0,
lora_names: str | list[str] | None = None, lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0, lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 4, max_dataset_size: int = 4,
...@@ -319,7 +319,8 @@ def run_test( ...@@ -319,7 +319,8 @@ def run_test(
pipeline, pipeline,
use_double_fb_cache=use_double_fb_cache, use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi, residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single) residual_diff_threshold_single=residual_diff_threshold_single,
)
run_pipeline( run_pipeline(
batch_size=batch_size, batch_size=batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment