fengzch-das / nunchaku · Commits

Commit 30ba84c5
Authored Apr 11, 2025 by muyangli
Parent: 748be0ab

    update tests
Showing 8 changed files with 40 additions and 17 deletions (+40 −17).
The commit adds a new multi-GPU device-placement test (tests/flux/test_device_id.py) and normalizes the Turing skip reason across the remaining Flux tests.
tests/flux/test_device_id.py        +23  −0
tests/flux/test_flux_cache.py        +1  −1
tests/flux/test_flux_dev.py          +1  −1
tests/flux/test_flux_dev_loras.py    +5  −5
tests/flux/test_flux_memory.py       +1  −1
tests/flux/test_flux_schnell.py      +1  −1
tests/flux/test_flux_tools.py        +7  −7
tests/flux/test_shuttle_jaguar.py    +1  −1
tests/flux/test_device_id.py  (new file, 0 → 100644)

import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing


@pytest.mark.skipif(
    is_turing() or torch.cuda.device_count() <= 1,
    reason="Skip tests due to using Turing GPUs or single GPU",
)
def test_device_id():
    precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
    torch_dtype = torch.float16 if is_turing("cuda:1") else torch.float32
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
    ).to("cuda:1")
    pipeline(
        "A cat holding a sign that says hello world",
        width=1024,
        height=1024,
        num_inference_steps=4,
        guidance_scale=0,
    )
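For context, is_turing(device) gates these tests on the architecture of the given GPU. Its implementation is not part of this diff; a minimal sketch, assuming it checks CUDA compute capability (Turing cards such as the RTX 20xx series report 7.5), could look like this:

import torch

def is_turing(device: str = "cuda") -> bool:
    # Hypothetical sketch, not nunchaku's actual implementation:
    # Turing GPUs report compute capability 7.5.
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) == (7, 5)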
tests/flux/test_flux_cache.py

@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test


-@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
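The expected_lpips parameter in these parametrizations bounds the perceptual distance between the quantized pipeline's output and a reference image. run_test's internals are not shown in this diff; a hypothetical check of that kind, using the lpips package, might look like the following sketch:

import lpips

# Hypothetical sketch of an LPIPS assertion, not nunchaku's run_test:
# img and ref are float tensors in [-1, 1] with shape (1, 3, H, W).
loss_fn = lpips.LPIPS(net="alex")

def assert_lpips_close(img, ref, expected_lpips: float):
    dist = loss_fn(img, ref).item()
    assert dist < expected_lpips, f"LPIPS {dist:.3f} exceeds {expected_lpips}"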
tests/flux/test_flux_dev.py

@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test


-@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
tests/flux/test_flux_dev_loras.py

@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [

@@ -35,7 +35,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_hypersd8_1536x2048():
     run_test(
         precision=get_precision(),

@@ -55,7 +55,7 @@ def test_flux_dev_hypersd8_1536x2048():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_turbo8_1024x1920():
     run_test(
         precision=get_precision(),

@@ -76,7 +76,7 @@ def test_flux_dev_turbo8_1024x1920():


 # lora composition
-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_turbo8_yarn_2048x1024():
     run_test(
         precision=get_precision(),

@@ -96,7 +96,7 @@ def test_flux_dev_turbo8_yarn_2048x1024():


 # large rank loras
-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_turbo8_yarn_1024x1024():
     run_test(
         precision=get_precision(),
tests/flux/test_flux_memory.py

@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
 from nunchaku.utils import get_precision, is_turing


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "use_qencoder,cpu_offload,memory_limit",
     [
tests/flux/test_flux_schnell.py

@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
tests/flux/test_flux_tools.py

@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_canny_dev():
     run_test(
         precision=get_precision(),

@@ -24,7 +24,7 @@ def test_flux_canny_dev():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_depth_dev():
     run_test(
         precision=get_precision(),

@@ -43,7 +43,7 @@ def test_flux_depth_dev():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_fill_dev():
     run_test(
         precision=get_precision(),

@@ -62,7 +62,7 @@ def test_flux_fill_dev():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_canny_lora():
     run_test(
         precision=get_precision(),

@@ -83,7 +83,7 @@ def test_flux_dev_canny_lora():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_depth_lora():
     run_test(
         precision=get_precision(),

@@ -104,7 +104,7 @@ def test_flux_dev_depth_lora():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_fill_dev_turbo():
     run_test(
         precision=get_precision(),

@@ -125,7 +125,7 @@ def test_flux_fill_dev_turbo():
     )


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_redux():
     run_test(
         precision=get_precision(),
tests/flux/test_shuttle_jaguar.py

@@ -4,7 +4,7 @@ from .utils import run_test
 from nunchaku.utils import get_precision, is_turing


-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.25)]
 )
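With these changes, the skip reasons read consistently across the Flux suite. To exercise the new device-placement test locally (it self-skips on Turing hardware or when fewer than two CUDA devices are visible), pytest's programmatic entry point is one option:

import pytest

# Run only the new test with verbose output; the skipif marker
# handles Turing GPUs and single-GPU machines automatically.
raise SystemExit(pytest.main(["tests/flux/test_device_id.py", "-v"]))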