Unverified Commit f82dc8fb authored by Muyang Li, committed by GitHub

[Dont Sync] release: pre-release v0.3.0dev0 (#309)

parents 6564cf70 8b1ca5f6
 from typing import Any, Callable

 import torch
 import torchvision.utils
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
 from peft.tuners import lora
@@ -9,6 +8,8 @@ from PIL import Image
 from torch import nn
 from torchvision.transforms import functional as F
+from nunchaku.utils import load_state_dict_in_safetensors


 class FluxPix2pixTurboPipeline(FluxPipeline):
     def update_alpha(self, alpha: float) -> None:
@@ -55,7 +56,9 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
             self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
         else:
             assert svdq_lora_path is not None
-            self.transformer.update_lora_params(svdq_lora_path)
+            sd = load_state_dict_in_safetensors(svdq_lora_path)
+            sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
+            self.transformer.update_lora_params(sd)
         self.update_alpha(alpha)

     @torch.no_grad()
...
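The change above stops handing a file path straight to `update_lora_params`: the LoRA weights are first loaded with `load_state_dict_in_safetensors` and diffusers-style keys prefixed with `transformer.` are stripped. A minimal sketch of that flow, assuming a Nunchaku FLUX transformer that exposes `update_lora_params`; the function name and path argument are hypothetical:

```python
from nunchaku.utils import load_state_dict_in_safetensors

def load_svdq_lora(transformer, svdq_lora_path: str) -> None:
    # Read the LoRA weights from the .safetensors file into a flat state dict.
    sd = load_state_dict_in_safetensors(svdq_lora_path)
    # Drop diffusers-style entries prefixed with "transformer." so only the
    # keys the quantized transformer expects remain.
    sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
    # Hand the filtered dict (rather than the raw path) to the model.
    transformer.update_lora_params(sd)
```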
@@ -82,15 +82,16 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         image_rotary_emb = image_rotary_emb.to(self.device)

         if controlnet_block_samples is not None:
-            controlnet_block_samples = (
-                torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
-            )
-        if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
-            controlnet_single_block_samples = (
-                torch.stack(controlnet_single_block_samples).to(self.device)
-                if len(controlnet_single_block_samples) > 0
-                else None
-            )
+            if len(controlnet_block_samples) > 0:
+                controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+            else:
+                controlnet_block_samples = None
+        if controlnet_single_block_samples is not None:
+            if len(controlnet_single_block_samples) > 0:
+                controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+            else:
+                controlnet_single_block_samples = None

         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
...
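The refactor above replaces the nested conditional expressions with explicit branches so that an empty list of ControlNet residuals is normalized to `None` instead of being stacked. A standalone sketch of the same logic, with a hypothetical helper name:

```python
import torch

def stack_controlnet_samples(samples, device):
    # Hypothetical helper mirroring the refactored branch above.
    # No residuals supplied at all: pass None through unchanged.
    if samples is None:
        return None
    # Non-empty list: stack the per-block residual tensors onto one device.
    if len(samples) > 0:
        return torch.stack(samples).to(device)
    # Empty list: normalize to None so downstream code sees a single sentinel.
    return None
```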
@@ -8,7 +8,7 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
-        (0.12, 1024, 1024, 30, None, 1, 0.212),
+        (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
     ],
 )
 def test_flux_dev_cache(
...
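The updated expectations select an LPIPS bound based on `get_precision()`. A hypothetical helper, not part of the repository, that expresses the same pattern once instead of repeating the ternary in every parametrize entry:

```python
from nunchaku.utils import get_precision

def precision_lpips(int4: float, fp4: float) -> float:
    # Hypothetical convenience wrapper: return the expected LPIPS bound for
    # the detected quantization precision ("int4" vs. the fp4 fallback).
    return int4 if get_precision() == "int4" else fp4

# e.g. precision_lpips(0.212, 0.144) reproduces the entry above.
```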
@@ -8,8 +8,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, 50, "flashattn2", False, 0.139),
-        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
+        (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
+        (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
     ],
 )
 def test_flux_dev(
...
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [
-        (25, "realism", 0.9, True, 0.136),
+        (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.112),
         # (25, "ghibsky", 1, False, 0.186),
         # (28, "anime", 1, False, 0.284),
-        (24, "sketch", 1, True, 0.291),
+        (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
         # (28, "yarn", 1, False, 0.211),
         # (25, "haunted_linework", 1, True, 0.317),
     ],
@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
         lora_strengths=[0, 1, 0, 0, 0, 0, 1],
         cache_threshold=0,
-        expected_lpips=0.310,
+        expected_lpips=0.310 if get_precision() == "int4" else 0.168,
     )
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.126),
-        (1024, 1024, "nunchaku-fp16", False, 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158),
-        (2048, 2048, "nunchaku-fp16", True, 0.166),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
 def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
...
@@ -20,7 +20,7 @@ def test_flux_canny_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
+        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
     )
@@ -39,7 +39,7 @@ def test_flux_depth_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
+        expected_lpips=0.137 if get_precision() == "int4" else 0.092,
     )
@@ -58,7 +58,7 @@ def test_flux_fill_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.046,
+        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
     )
@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora():
         cache_threshold=0,
         lora_names="depth",
         lora_strengths=0.85,
-        expected_lpips=0.181,
+        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
     )
@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo():
         cache_threshold=0,
         lora_names="turbo8",
         lora_strengths=1,
-        expected_lpips=0.036,
+        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
     )
@@ -140,5 +140,5 @@ def test_flux_dev_redux():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
+        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
     )
@@ -9,8 +9,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
-        (1024, 1024, "nunchaku-fp16", False, 0.140, 2),
-        (1920, 1080, "flashattn2", False, 0.160, 4),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_int4_schnell(
...
@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
+    "height,width,attention_impl,cpu_offload,expected_lpips",
+    [(1024, 1024, "nunchaku-fp16", False, 0.209 if get_precision() == "int4" else 0.148)],
 )
 def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
     run_test(
...
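The shuttle-jaguar test above combines the two gates used throughout these files: `is_turing()` to skip unsupported GPUs and `get_precision()` to pick the LPIPS bound. A hypothetical standalone example of that combination, assuming a CUDA device is available when the module is collected; the test name and placeholder assertion are illustrative only:

```python
import pytest
from nunchaku.utils import get_precision, is_turing

@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize("expected_lpips", [0.209 if get_precision() == "int4" else 0.148])
def test_lpips_bound_selection(expected_lpips: float):
    # Placeholder assertion: a real test would compare a computed LPIPS score
    # against this bound via run_test.
    assert 0.0 < expected_lpips < 1.0
```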