Commit 85959824 authored by Muyang Li, committed by muyangli

[minor] add fp4 test expected lpips

parent 45e055ce
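
Every hunk below applies the same pattern: the expected LPIPS threshold is selected at test-collection time via nunchaku.utils.get_precision(), keeping the existing int4 value and adding a newly measured fp4 one. As a rough illustration of how run_test likely consumes such a threshold, here is a minimal hedged sketch of an LPIPS assertion using the lpips package; the function name and tensor conventions are assumptions, not the repo's actual helper.

import lpips
import torch


def assert_lpips_below(generated: torch.Tensor, reference: torch.Tensor, expected_lpips: float) -> None:
    # Sketch only: run_test in this repo presumably renders images with the
    # quantized model and compares them to reference outputs roughly like this.
    # Inputs assumed to be float tensors in [-1, 1] with shape (N, 3, H, W).
    loss_fn = lpips.LPIPS(net="alex")  # AlexNet-backed perceptual distance
    with torch.no_grad():
        distance = loss_fn(generated, reference).mean().item()
    assert distance < expected_lpips, f"LPIPS {distance:.3f} exceeds threshold {expected_lpips:.3f}"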
@@ -8,7 +8,7 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
-        (0.12, 1024, 1024, 30, None, 1, 0.212),
+        (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
     ],
 )
 def test_flux_dev_cache(
@@ -8,8 +8,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, 50, "flashattn2", False, 0.139),
-        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
+        (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
+        (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
     ],
 )
 def test_flux_dev(
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [
-        (25, "realism", 0.9, True, 0.136),
+        (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.1),
         # (25, "ghibsky", 1, False, 0.186),
         # (28, "anime", 1, False, 0.284),
-        (24, "sketch", 1, True, 0.291),
+        (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
         # (28, "yarn", 1, False, 0.211),
         # (25, "haunted_linework", 1, True, 0.317),
     ],
@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
         lora_strengths=[0, 1, 0, 0, 0, 0, 1],
         cache_threshold=0,
-        expected_lpips=0.310,
+        expected_lpips=0.310 if get_precision() == "int4" else 0.150,
     )
@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.126),
-        (1024, 1024, "nunchaku-fp16", False, 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158),
-        (2048, 2048, "nunchaku-fp16", True, 0.166),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
 def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
@@ -20,7 +20,7 @@ def test_flux_canny_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
+        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
     )
@@ -39,7 +39,7 @@ def test_flux_depth_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
+        expected_lpips=0.137 if get_precision() == "int4" else 0.092,
     )
@@ -58,7 +58,7 @@ def test_flux_fill_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.046,
+        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
     )
@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora():
         cache_threshold=0,
         lora_names="depth",
         lora_strengths=0.85,
-        expected_lpips=0.181,
+        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
     )
@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo():
         cache_threshold=0,
         lora_names="turbo8",
         lora_strengths=1,
-        expected_lpips=0.036,
+        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
     )
@@ -140,5 +140,5 @@ def test_flux_dev_redux():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
+        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
     )
@@ -9,8 +9,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
-        (1024, 1024, "nunchaku-fp16", False, 0.140, 2),
-        (1920, 1080, "flashattn2", False, 0.160, 4),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_int4_schnell(
@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
+    "height,width,attention_impl,cpu_offload,expected_lpips",
+    [(1024, 1024, "nunchaku-fp16", False, 0.209 if get_precision() == "int4" else 0.148)],
 )
 def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
     run_test(