fengzch-das / nunchaku · Commits

Commit 85959824
Authored Apr 20, 2025 by Muyang Li; committed by muyangli on Apr 20, 2025
Parent: 45e055ce

[minor] add fp4 test expected lpips

Showing 7 changed files, with 20 additions and 19 deletions:
tests/flux/test_flux_cache.py      +1 -1
tests/flux/test_flux_dev.py        +2 -2
tests/flux/test_flux_dev_loras.py  +3 -3
tests/flux/test_flux_schnell.py    +4 -4
tests/flux/test_flux_tools.py      +6 -6
tests/flux/test_multiple_batch.py  +2 -2
tests/flux/test_shuttle_jaguar.py  +2 -1
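Every hunk below applies the same edit: the expected_lpips value passed to each test is made (or updated to be) conditional on nunchaku.utils.get_precision(), keeping the existing INT4 threshold while adding or correcting the FP4 one. A minimal sketch of the pattern follows; it assumes get_precision() reports "int4" on INT4 builds (anything else is treated as FP4), and the test name and threshold values here are illustrative placeholders rather than part of this commit.

# Sketch of the per-precision LPIPS threshold pattern used throughout this commit.
# Placeholder test; the real suites forward expected_lpips to tests/flux/utils.run_test.
import pytest

from nunchaku.utils import get_precision


@pytest.mark.parametrize(
    "expected_lpips",
    # one threshold per quantization: keep the INT4 value, add the FP4 one
    [0.212 if get_precision() == "int4" else 0.144],
)
def test_lpips_threshold_pattern(expected_lpips: float):
    # the real tests generate images and assert the measured LPIPS against
    # a reference stays below this per-precision bound
    assert 0 < expected_lpips < 1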
tests/flux/test_flux_cache.py

@@ -8,7 +8,7 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
-        (0.12, 1024, 1024, 30, None, 1, 0.212),
+        (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
     ],
 )
 def test_flux_dev_cache(
tests/flux/test_flux_dev.py

@@ -8,8 +8,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, 50, "flashattn2", False, 0.139),
-        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
+        (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
+        (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133),
     ],
 )
 def test_flux_dev(
tests/flux/test_flux_dev_loras.py

@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [
-        (25, "realism", 0.9, True, 0.136),
+        (25, "realism", 0.9, True, 0.136 if get_precision() == "int4" else 0.1),
         # (25, "ghibsky", 1, False, 0.186),
         # (28, "anime", 1, False, 0.284),
-        (24, "sketch", 1, True, 0.291),
+        (24, "sketch", 1, True, 0.291 if get_precision() == "int4" else 0.182),
         # (28, "yarn", 1, False, 0.211),
         # (25, "haunted_linework", 1, True, 0.317),
     ],

@@ -51,5 +51,5 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
         lora_strengths=[0, 1, 0, 0, 0, 0, 1],
         cache_threshold=0,
-        expected_lpips=0.310,
+        expected_lpips=0.310 if get_precision() == "int4" else 0.150,
     )
tests/flux/test_flux_schnell.py

@@ -8,10 +8,10 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.126),
-        (1024, 1024, "nunchaku-fp16", False, 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158),
-        (2048, 2048, "nunchaku-fp16", True, 0.166),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
 def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
tests/flux/test_flux_tools.py

@@ -20,7 +20,7 @@ def test_flux_canny_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
+        expected_lpips=0.076 if get_precision() == "int4" else 0.090,
     )

@@ -39,7 +39,7 @@ def test_flux_depth_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
+        expected_lpips=0.137 if get_precision() == "int4" else 0.092,
     )

@@ -58,7 +58,7 @@ def test_flux_fill_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.046,
+        expected_lpips=0.046 if get_precision() == "int4" else 0.021,
     )

@@ -100,7 +100,7 @@ def test_flux_dev_depth_lora():
         cache_threshold=0,
         lora_names="depth",
         lora_strengths=0.85,
-        expected_lpips=0.181,
+        expected_lpips=0.181 if get_precision() == "int4" else 0.196,
     )

@@ -121,7 +121,7 @@ def test_flux_fill_dev_turbo():
         cache_threshold=0,
         lora_names="turbo8",
         lora_strengths=1,
-        expected_lpips=0.036,
+        expected_lpips=0.036 if get_precision() == "int4" else 0.030,
     )

@@ -140,5 +140,5 @@ def test_flux_dev_redux():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
+        expected_lpips=(0.162 if get_precision() == "int4" else 0.466),  # not sure why the fp4 model is so different
     )
tests/flux/test_multiple_batch.py

@@ -9,8 +9,8 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
-        (1024, 1024, "nunchaku-fp16", False, 0.140, 2),
-        (1920, 1080, "flashattn2", False, 0.160, 4),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_int4_schnell(
tests/flux/test_shuttle_jaguar.py

@@ -6,7 +6,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
+    "height,width,attention_impl,cpu_offload,expected_lpips",
+    [(1024, 1024, "nunchaku-fp16", False, 0.209 if get_precision() == "int4" else 0.148)],
 )
 def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
     run_test(