Commit 85959824 authored by Muyang Li's avatar Muyang Li Committed by muyangli
Browse files

[minor] add fp4 test expected lpips

parent 45e055ce
...@@ -8,7 +8,7 @@ from .utils import run_test ...@@ -8,7 +8,7 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips", "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[ [
(0.12, 1024, 1024, 30, None, 1, 0.212), (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
], ],
) )
def test_flux_dev_cache( def test_flux_dev_cache(
......
...@@ -8,8 +8,8 @@ from .utils import run_test ...@@ -8,8 +8,8 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips", "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[ [
(1024, 1024, 50, "flashattn2", False, 0.139), (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
(2048, 512, 25, "nunchaku-fp16", False, 0.168), (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
], ],
) )
def test_flux_dev( def test_flux_dev(
......
...@@ -8,10 +8,10 @@ from .utils import run_test ...@@ -8,10 +8,10 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips", "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[ [
(25, "realism", 0.9, True, 0.136), (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.1),
# (25, "ghibsky", 1, False, 0.186), # (25, "ghibsky", 1, False, 0.186),
# (28, "anime", 1, False, 0.284), # (28, "anime", 1, False, 0.284),
(24, "sketch", 1, True, 0.291), (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
# (28, "yarn", 1, False, 0.211), # (28, "yarn", 1, False, 0.211),
# (25, "haunted_linework", 1, True, 0.317), # (25, "haunted_linework", 1, True, 0.317),
], ],
...@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024(): ...@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"], lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
lora_strengths=[0, 1, 0, 0, 0, 0, 1], lora_strengths=[0, 1, 0, 0, 0, 0, 1],
cache_threshold=0, cache_threshold=0,
expected_lpips=0.310, expected_lpips=0.310 if get_precision() == "int4" else 0.150,
) )
...@@ -8,10 +8,10 @@ from .utils import run_test ...@@ -8,10 +8,10 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", "height,width,attention_impl,cpu_offload,expected_lpips",
[ [
(1024, 1024, "flashattn2", False, 0.126), (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
(1024, 1024, "nunchaku-fp16", False, 0.126), (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
(1920, 1080, "nunchaku-fp16", False, 0.158), (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
(2048, 2048, "nunchaku-fp16", True, 0.166), (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
], ],
) )
def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float): def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
......
...@@ -20,7 +20,7 @@ def test_flux_canny_dev(): ...@@ -20,7 +20,7 @@ def test_flux_canny_dev():
attention_impl="nunchaku-fp16", attention_impl="nunchaku-fp16",
cpu_offload=False, cpu_offload=False,
cache_threshold=0, cache_threshold=0,
expected_lpips=0.076 if get_precision() == "int4" else 0.164, expected_lpips=0.076 if get_precision() == "int4" else 0.090,
) )
...@@ -39,7 +39,7 @@ def test_flux_depth_dev(): ...@@ -39,7 +39,7 @@ def test_flux_depth_dev():
attention_impl="nunchaku-fp16", attention_impl="nunchaku-fp16",
cpu_offload=False, cpu_offload=False,
cache_threshold=0, cache_threshold=0,
expected_lpips=0.137 if get_precision() == "int4" else 0.120, expected_lpips=0.137 if get_precision() == "int4" else 0.092,
) )
...@@ -58,7 +58,7 @@ def test_flux_fill_dev(): ...@@ -58,7 +58,7 @@ def test_flux_fill_dev():
attention_impl="nunchaku-fp16", attention_impl="nunchaku-fp16",
cpu_offload=False, cpu_offload=False,
cache_threshold=0, cache_threshold=0,
expected_lpips=0.046, expected_lpips=0.046 if get_precision() == "int4" else 0.021,
) )
...@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora(): ...@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora():
cache_threshold=0, cache_threshold=0,
lora_names="depth", lora_names="depth",
lora_strengths=0.85, lora_strengths=0.85,
expected_lpips=0.181, expected_lpips=0.181 if get_precision() == "int4" else 0.196,
) )
...@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo(): ...@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo():
cache_threshold=0, cache_threshold=0,
lora_names="turbo8", lora_names="turbo8",
lora_strengths=1, lora_strengths=1,
expected_lpips=0.036, expected_lpips=0.036 if get_precision() == "int4" else 0.030,
) )
...@@ -140,5 +140,5 @@ def test_flux_dev_redux(): ...@@ -140,5 +140,5 @@ def test_flux_dev_redux():
attention_impl="nunchaku-fp16", attention_impl="nunchaku-fp16",
cpu_offload=False, cpu_offload=False,
cache_threshold=0, cache_threshold=0,
expected_lpips=(0.162 if get_precision() == "int4" else 0.5), # not sure why the fp4 model is so different expected_lpips=(0.162 if get_precision() == "int4" else 0.466), # not sure why the fp4 model is so different
) )
...@@ -9,8 +9,8 @@ from .utils import run_test ...@@ -9,8 +9,8 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips,batch_size", "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
[ [
(1024, 1024, "nunchaku-fp16", False, 0.140, 2), (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
(1920, 1080, "flashattn2", False, 0.160, 4), (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
], ],
) )
def test_int4_schnell( def test_int4_schnell(
......
...@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing ...@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)] "height,width,attention_impl,cpu_offload,expected_lpips",
[(1024, 1024, "nunchaku-fp16", False, 0.209 if get_precision() == "int4" else 0.148)],
) )
def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float): def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test( run_test(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment