"vscode:/vscode.git/clone" did not exist on "6c64741933c64df276b5ede21f62777dbe079cfd"
Unverified commit f82dc8fb authored by Muyang Li, committed by GitHub

[Dont Sync] release: pre-release v0.3.0dev0 (#309)

parents 6564cf70 8b1ca5f6
 from typing import Any, Callable
 import torch
-import torchvision.utils
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
 from peft.tuners import lora
...
@@ -9,6 +8,8 @@ from PIL import Image
 from torch import nn
 from torchvision.transforms import functional as F
+
+from nunchaku.utils import load_state_dict_in_safetensors
 class FluxPix2pixTurboPipeline(FluxPipeline):
     def update_alpha(self, alpha: float) -> None:
@@ -55,7 +56,9 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
             self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
         else:
             assert svdq_lora_path is not None
-            self.transformer.update_lora_params(svdq_lora_path)
+            sd = load_state_dict_in_safetensors(svdq_lora_path)
+            sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
+            self.transformer.update_lora_params(sd)
         self.update_alpha(alpha)

     @torch.no_grad()
...
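Note: the hunk above reroutes SVDQuant LoRA loading through an in-memory state dict instead of handing a file path to the transformer. A minimal sketch of the new flow, using only the names visible in this diff (load_state_dict_in_safetensors, update_lora_params); the surrounding pipeline code is not shown here, so the wrapper function below is illustrative, not the actual method:

    from nunchaku.utils import load_state_dict_in_safetensors

    def load_svdq_lora(transformer, svdq_lora_path: str) -> None:
        # Read the LoRA weights from a .safetensors file into a plain dict.
        sd = load_state_dict_in_safetensors(svdq_lora_path)
        # Drop entries whose keys carry the diffusers-style "transformer."
        # prefix, keeping only the keys the Nunchaku transformer expects.
        sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
        transformer.update_lora_params(sd)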
...
@@ -82,15 +82,16 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         image_rotary_emb = image_rotary_emb.to(self.device)

         if controlnet_block_samples is not None:
-            controlnet_block_samples = (
-                torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
-            )
-        if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
-            controlnet_single_block_samples = (
-                torch.stack(controlnet_single_block_samples).to(self.device)
-                if len(controlnet_single_block_samples) > 0
-                else None
-            )
+            if len(controlnet_block_samples) > 0:
+                controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
+            else:
+                controlnet_block_samples = None
+
+        if controlnet_single_block_samples is not None:
+            if len(controlnet_single_block_samples) > 0:
+                controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
+            else:
+                controlnet_single_block_samples = None

         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
...
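Note: besides being flatter to read, the rewritten guard appears to normalize an empty controlnet_single_block_samples list to None, which the old combined "is not None and len(...) > 0" condition left as an empty list. The length check also has to happen before stacking, because torch.stack rejects an empty sequence. A minimal standalone sketch of the normalization, with plain tensors standing in for the ControlNet residuals:

    import torch

    def normalize_block_samples(samples, device="cpu"):
        # torch.stack raises "stack expects a non-empty TensorList" on an
        # empty sequence, so an empty list is mapped to None instead.
        if samples is not None:
            if len(samples) > 0:
                samples = torch.stack(samples).to(device)
            else:
                samples = None
        return samples

    print(normalize_block_samples(None))                    # None
    print(normalize_block_samples([]))                      # None
    print(normalize_block_samples([torch.zeros(2)]).shape)  # torch.Size([1, 2])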
...
@@ -8,7 +8,7 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
-        (0.12, 1024, 1024, 30, None, 1, 0.212),
+        (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
     ],
 )
 def test_flux_dev_cache(
...
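Note: the hunk above and every test hunk below apply the same pattern: the expected LPIPS threshold is now selected at test-collection time from the quantization precision that get_precision() reports for the current GPU (the diff's comments suggest "int4" versus "fp4"). A sketch of the pattern, with a hypothetical expected_lpips helper in place of the inline conditionals the tests actually use:

    import pytest
    from nunchaku.utils import get_precision  # "int4" or "fp4", depending on the GPU

    def expected_lpips(int4: float, fp4: float) -> float:
        # Hypothetical helper: pick the reference threshold matching the
        # precision the installed quantized models will run at.
        return int4 if get_precision() == "int4" else fp4

    @pytest.mark.parametrize(
        "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
        [(0.12, 1024, 1024, 30, None, 1, expected_lpips(0.212, 0.144))],
    )
    def test_flux_dev_cache(cache_threshold, height, width, num_inference_steps,
                            lora_name, lora_strength, expected_lpips):
        ...  # body elided; see the real test file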
...
@@ -8,8 +8,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, 50, "flashattn2", False, 0.139),
-        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
+        (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
+        (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
     ],
 )
 def test_flux_dev(
...
...
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [
-        (25, "realism", 0.9, True, 0.136),
+        (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.112),
         # (25, "ghibsky", 1, False, 0.186),
         # (28, "anime", 1, False, 0.284),
-        (24, "sketch", 1, True, 0.291),
+        (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
         # (28, "yarn", 1, False, 0.211),
         # (25, "haunted_linework", 1, True, 0.317),
     ],
...
@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
         lora_strengths=[0, 1, 0, 0, 0, 0, 1],
         cache_threshold=0,
-        expected_lpips=0.310,
+        expected_lpips=0.310 if get_precision() == "int4" else 0.168,
     )
...
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.126),
-        (1024, 1024, "nunchaku-fp16", False, 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158),
-        (2048, 2048, "nunchaku-fp16", True, 0.166),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
 def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
...
...
@@ -20,7 +20,7 @@ def test_flux_canny_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
+        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
     )
...
@@ -39,7 +39,7 @@ def test_flux_depth_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
+        expected_lpips=0.137 if get_precision() == "int4" else 0.092,
     )
...
@@ -58,7 +58,7 @@ def test_flux_fill_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.046,
+        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
     )
...
@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora():
         cache_threshold=0,
         lora_names="depth",
         lora_strengths=0.85,
-        expected_lpips=0.181,
+        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
     )
...
@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo():
         cache_threshold=0,
         lora_names="turbo8",
         lora_strengths=1,
-        expected_lpips=0.036,
+        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
     )
...
@@ -140,5 +140,5 @@ def test_flux_dev_redux():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
+        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
     )
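Note on the metric these thresholds gate: LPIPS scores the perceptual distance between the quantized model's output and a full-precision reference render, lower meaning closer. The test helper's metric code is not part of this diff, so the sketch below uses the commonly used lpips package as an assumption, not necessarily what run_test does:

    import lpips
    import torch

    # LPIPS compares deep-network features of two images;
    # inputs are NCHW tensors scaled to [-1, 1].
    loss_fn = lpips.LPIPS(net="alex")  # AlexNet backbone, the common default

    ref = torch.rand(1, 3, 1024, 1024) * 2 - 1  # stand-in for the 16-bit reference
    out = torch.rand(1, 3, 1024, 1024) * 2 - 1  # stand-in for the quantized output
    score = loss_fn(ref, out).item()
    assert score <= 0.162, f"LPIPS {score:.3f} exceeds the expected threshold"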
...
@@ -9,8 +9,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
-        (1024, 1024, "nunchaku-fp16", False, 0.140, 2),
-        (1920, 1080, "flashattn2", False, 0.160, 4),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_int4_schnell(
...
...
@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
+    "height,width,attention_impl,cpu_offload,expected_lpips",
+    [(1024, 1024, "nunchaku-fp16", False, 0.209 if get_precision() == "int4" else 0.148)],
 )
 def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
     run_test(
...