Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
748be0ab
"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c4ba257c644daee9d3c906339826216afbe605bf"
Commit
748be0ab
authored
Apr 11, 2025
by
muyangli
Browse files
cleaning some tests
parent
102a0f7d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
27 additions
and
33 deletions
+27
-33
tests/data/__init__.py
tests/data/__init__.py
+1
-6
tests/flux/test_flux_cache.py
tests/flux/test_flux_cache.py
+0
-1
tests/flux/test_flux_dev_loras.py
tests/flux/test_flux_dev_loras.py
+9
-8
tests/flux/test_flux_schnell.py
tests/flux/test_flux_schnell.py
+1
-2
tests/flux/test_flux_tools.py
tests/flux/test_flux_tools.py
+1
-1
tests/flux/test_shuttle_jaguar.py
tests/flux/test_shuttle_jaguar.py
+1
-2
tests/flux/test_turing.py
tests/flux/test_turing.py
+3
-5
tests/flux/utils.py
tests/flux/utils.py
+11
-8
No files found.
tests/data/__init__.py
View file @
748be0ab
...
@@ -7,12 +7,7 @@ from huggingface_hub import snapshot_download
...
@@ -7,12 +7,7 @@ from huggingface_hub import snapshot_download
from
nunchaku.utils
import
fetch_or_download
from
nunchaku.utils
import
fetch_or_download
__all__
=
[
"get_dataset"
,
"load_dataset_yaml"
,
"download_hf_dataset"
]
__all__
=
[
"get_dataset"
,
"load_dataset_yaml"
]
def
download_hf_dataset
(
repo_id
:
str
=
"mit-han-lab/nunchaku-test"
,
local_dir
:
str
|
None
=
None
)
->
str
:
path
=
snapshot_download
(
repo_id
=
repo_id
,
repo_type
=
"dataset"
,
local_dir
=
local_dir
)
return
path
def
load_dataset_yaml
(
meta_path
:
str
,
max_dataset_size
:
int
=
-
1
,
repeat
:
int
=
4
)
->
dict
:
def
load_dataset_yaml
(
meta_path
:
str
,
max_dataset_size
:
int
=
-
1
,
repeat
:
int
=
4
)
->
dict
:
...
...
tests/flux/test_flux_cache.py
View file @
748be0ab
...
@@ -9,7 +9,6 @@ from .utils import run_test
...
@@ -9,7 +9,6 @@ from .utils import run_test
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips"
,
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips"
,
[
[
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.26
),
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.26
),
(
0.12
,
512
,
2048
,
30
,
"anime"
,
1
,
0.4
),
],
],
)
)
def
test_flux_dev_loras
(
def
test_flux_dev_loras
(
...
...
tests/flux/test_flux_dev_loras.py
View file @
748be0ab
...
@@ -10,10 +10,10 @@ from .utils import run_test
...
@@ -10,10 +10,10 @@ from .utils import run_test
[
[
(
25
,
"realism"
,
0.9
,
True
,
0.178
),
(
25
,
"realism"
,
0.9
,
True
,
0.178
),
(
25
,
"ghibsky"
,
1
,
False
,
0.164
),
(
25
,
"ghibsky"
,
1
,
False
,
0.164
),
(
28
,
"anime"
,
1
,
False
,
0.284
),
#
(28, "anime", 1, False, 0.284),
(
24
,
"sketch"
,
1
,
True
,
0.223
),
(
24
,
"sketch"
,
1
,
True
,
0.223
),
(
28
,
"yarn"
,
1
,
False
,
0.211
),
#
(28, "yarn", 1, False, 0.211),
(
25
,
"haunted_linework"
,
1
,
True
,
0.317
),
#
(25, "haunted_linework", 1, True, 0.317),
],
],
)
)
def
test_flux_dev_loras
(
num_inference_steps
,
lora_name
,
lora_strength
,
cpu_offload
,
expected_lpips
):
def
test_flux_dev_loras
(
num_inference_steps
,
lora_name
,
lora_strength
,
cpu_offload
,
expected_lpips
):
...
@@ -26,6 +26,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo
...
@@ -26,6 +26,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo
num_inference_steps
=
num_inference_steps
,
num_inference_steps
=
num_inference_steps
,
guidance_scale
=
3.5
,
guidance_scale
=
3.5
,
use_qencoder
=
False
,
use_qencoder
=
False
,
attention_impl
=
"nunchaku-fp16"
,
cpu_offload
=
cpu_offload
,
cpu_offload
=
cpu_offload
,
lora_names
=
lora_name
,
lora_names
=
lora_name
,
lora_strengths
=
lora_strength
,
lora_strengths
=
lora_strength
,
...
@@ -55,13 +56,13 @@ def test_flux_dev_hypersd8_1536x2048():
...
@@ -55,13 +56,13 @@ def test_flux_dev_hypersd8_1536x2048():
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to Turing GPUs"
)
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to Turing GPUs"
)
def
test_flux_dev_turbo8_
2048x2048
():
def
test_flux_dev_turbo8_
1024x1920
():
run_test
(
run_test
(
precision
=
get_precision
(),
precision
=
get_precision
(),
model_name
=
"flux.1-dev"
,
model_name
=
"flux.1-dev"
,
dataset_name
=
"MJHQ"
,
dataset_name
=
"MJHQ"
,
height
=
2048
,
height
=
1024
,
width
=
20
48
,
width
=
19
20
,
num_inference_steps
=
8
,
num_inference_steps
=
8
,
guidance_scale
=
3.5
,
guidance_scale
=
3.5
,
use_qencoder
=
False
,
use_qencoder
=
False
,
...
@@ -100,7 +101,7 @@ def test_flux_dev_turbo8_yarn_1024x1024():
...
@@ -100,7 +101,7 @@ def test_flux_dev_turbo8_yarn_1024x1024():
run_test
(
run_test
(
precision
=
get_precision
(),
precision
=
get_precision
(),
model_name
=
"flux.1-dev"
,
model_name
=
"flux.1-dev"
,
dataset_name
=
"
ghibsky
"
,
dataset_name
=
"
haunted_linework
"
,
height
=
1024
,
height
=
1024
,
width
=
1024
,
width
=
1024
,
num_inference_steps
=
8
,
num_inference_steps
=
8
,
...
@@ -108,7 +109,7 @@ def test_flux_dev_turbo8_yarn_1024x1024():
...
@@ -108,7 +109,7 @@ def test_flux_dev_turbo8_yarn_1024x1024():
use_qencoder
=
False
,
use_qencoder
=
False
,
cpu_offload
=
True
,
cpu_offload
=
True
,
lora_names
=
[
"realism"
,
"ghibsky"
,
"anime"
,
"sketch"
,
"yarn"
,
"haunted_linework"
,
"turbo8"
],
lora_names
=
[
"realism"
,
"ghibsky"
,
"anime"
,
"sketch"
,
"yarn"
,
"haunted_linework"
,
"turbo8"
],
lora_strengths
=
[
0
,
1
,
0
,
0
,
0
,
0
,
1
],
lora_strengths
=
[
0
,
0
,
0
,
0
,
0
,
1
,
1
],
cache_threshold
=
0
,
cache_threshold
=
0
,
expected_lpips
=
0.44
,
expected_lpips
=
0.44
,
)
)
tests/flux/test_flux_schnell.py
View file @
748be0ab
...
@@ -10,9 +10,8 @@ from .utils import run_test
...
@@ -10,9 +10,8 @@ from .utils import run_test
[
[
(
1024
,
1024
,
"flashattn2"
,
False
,
0.250
),
(
1024
,
1024
,
"flashattn2"
,
False
,
0.250
),
(
1024
,
1024
,
"nunchaku-fp16"
,
False
,
0.255
),
(
1024
,
1024
,
"nunchaku-fp16"
,
False
,
0.255
),
(
1024
,
1024
,
"flashattn2"
,
True
,
0.250
),
(
1920
,
1080
,
"nunchaku-fp16"
,
False
,
0.253
),
(
1920
,
1080
,
"nunchaku-fp16"
,
False
,
0.253
),
(
2048
,
2048
,
"
flashattn2
"
,
True
,
0.274
),
(
2048
,
2048
,
"
nunchaku-fp16
"
,
True
,
0.274
),
],
],
)
)
def
test_int4_schnell
(
height
:
int
,
width
:
int
,
attention_impl
:
str
,
cpu_offload
:
bool
,
expected_lpips
:
float
):
def
test_int4_schnell
(
height
:
int
,
width
:
int
,
attention_impl
:
str
,
cpu_offload
:
bool
,
expected_lpips
:
float
):
...
...
tests/flux/test_flux_tools.py
View file @
748be0ab
...
@@ -140,5 +140,5 @@ def test_flux_dev_redux():
...
@@ -140,5 +140,5 @@ def test_flux_dev_redux():
attention_impl
=
"nunchaku-fp16"
,
attention_impl
=
"nunchaku-fp16"
,
cpu_offload
=
False
,
cpu_offload
=
False
,
cache_threshold
=
0
,
cache_threshold
=
0
,
expected_lpips
=
0.198
if
get_precision
()
==
"int4"
else
0.
55
,
# redux seems to generate different images on 5090
expected_lpips
=
(
0.198
if
get_precision
()
==
"int4"
else
0.
198
),
)
)
tests/flux/test_shuttle_jaguar.py
View file @
748be0ab
...
@@ -6,8 +6,7 @@ from nunchaku.utils import get_precision, is_turing
...
@@ -6,8 +6,7 @@ from nunchaku.utils import get_precision, is_turing
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to Turing GPUs"
)
@
pytest
.
mark
.
skipif
(
is_turing
(),
reason
=
"Skip tests due to Turing GPUs"
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"height,width,attention_impl,cpu_offload,expected_lpips"
,
"height,width,attention_impl,cpu_offload,expected_lpips"
,
[(
1024
,
1024
,
"nunchaku-fp16"
,
False
,
0.25
)]
[(
1024
,
1024
,
"flashattn2"
,
False
,
0.25
),
(
2048
,
512
,
"nunchaku-fp16"
,
False
,
0.25
)],
)
)
def
test_shuttle_jaguar
(
height
:
int
,
width
:
int
,
attention_impl
:
str
,
cpu_offload
:
bool
,
expected_lpips
:
float
):
def
test_shuttle_jaguar
(
height
:
int
,
width
:
int
,
attention_impl
:
str
,
cpu_offload
:
bool
,
expected_lpips
:
float
):
run_test
(
run_test
(
...
...
tests/flux/test_turing.py
View file @
748be0ab
import
pytest
import
pytest
from
nunchaku.utils
import
get_precision
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
from
.utils
import
run_test
@
pytest
.
mark
.
skipif
(
get_precision
()
==
"fp4"
,
reason
=
"Blackwell
GPUs. Skip tests
for Turing
."
)
@
pytest
.
mark
.
skipif
(
not
is_turing
(),
reason
=
"Not turing
GPUs. Skip tests."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips"
,
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips"
,
[
[
(
1024
,
1024
,
50
,
True
,
None
,
0.253
),
(
1024
,
1024
,
50
,
True
,
"enabled"
,
0.258
),
(
1024
,
1024
,
50
,
True
,
"enabled"
,
0.258
),
(
1024
,
1024
,
50
,
True
,
"always"
,
0.257
),
],
],
)
)
def
test_flux_dev
(
def
test_flux_dev
_on_turing
(
height
:
int
,
width
:
int
,
num_inference_steps
:
int
,
cpu_offload
:
bool
,
i2f_mode
:
str
|
None
,
expected_lpips
:
float
height
:
int
,
width
:
int
,
num_inference_steps
:
int
,
cpu_offload
:
bool
,
i2f_mode
:
str
|
None
,
expected_lpips
:
float
):
):
run_test
(
run_test
(
...
...
tests/flux/utils.py
View file @
748be0ab
...
@@ -10,7 +10,7 @@ from tqdm import tqdm
...
@@ -10,7 +10,7 @@ from tqdm import tqdm
import
nunchaku
import
nunchaku
from
nunchaku
import
NunchakuFluxTransformer2dModel
,
NunchakuT5EncoderModel
from
nunchaku
import
NunchakuFluxTransformer2dModel
,
NunchakuT5EncoderModel
from
nunchaku.lora.flux.compose
import
compose_lora
from
nunchaku.lora.flux.compose
import
compose_lora
from
..data
import
download_hf_dataset
,
get_dataset
from
..data
import
get_dataset
from
..utils
import
already_generate
,
compute_lpips
,
hash_str_to_int
from
..utils
import
already_generate
,
compute_lpips
,
hash_str_to_int
ORIGINAL_REPO_MAP
=
{
ORIGINAL_REPO_MAP
=
{
...
@@ -117,7 +117,7 @@ def run_test(
...
@@ -117,7 +117,7 @@ def run_test(
cache_threshold
:
float
=
0
,
cache_threshold
:
float
=
0
,
lora_names
:
str
|
list
[
str
]
|
None
=
None
,
lora_names
:
str
|
list
[
str
]
|
None
=
None
,
lora_strengths
:
float
|
list
[
float
]
=
1.0
,
lora_strengths
:
float
|
list
[
float
]
=
1.0
,
max_dataset_size
:
int
=
20
,
max_dataset_size
:
int
=
8
,
i2f_mode
:
str
|
None
=
None
,
i2f_mode
:
str
|
None
=
None
,
expected_lpips
:
float
=
0.5
,
expected_lpips
:
float
=
0.5
,
):
):
...
@@ -153,10 +153,7 @@ def run_test(
...
@@ -153,10 +153,7 @@ def run_test(
for
lora_name
,
lora_strength
in
zip
(
lora_names
,
lora_strengths
):
for
lora_name
,
lora_strength
in
zip
(
lora_names
,
lora_strengths
):
folder_name
+=
f
"-
{
lora_name
}
_
{
lora_strength
}
"
folder_name
+=
f
"-
{
lora_name
}
_
{
lora_strength
}
"
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
"test_results"
,
"ref"
)):
ref_root
=
os
.
path
.
join
(
"test_results"
,
"ref"
)
ref_root
=
download_hf_dataset
(
local_dir
=
os
.
path
.
join
(
"test_results"
,
"ref"
))
else
:
ref_root
=
os
.
path
.
join
(
"test_results"
,
"ref"
)
save_dir_16bit
=
os
.
path
.
join
(
ref_root
,
dtype_str
,
model_name
,
folder_name
)
save_dir_16bit
=
os
.
path
.
join
(
ref_root
,
dtype_str
,
model_name
,
folder_name
)
if
task
in
[
"t2i"
,
"redux"
]:
if
task
in
[
"t2i"
,
"redux"
]:
...
@@ -171,7 +168,13 @@ def run_test(
...
@@ -171,7 +168,13 @@ def run_test(
if
not
already_generate
(
save_dir_16bit
,
max_dataset_size
):
if
not
already_generate
(
save_dir_16bit
,
max_dataset_size
):
pipeline_init_kwargs
=
{
"text_encoder"
:
None
,
"text_encoder2"
:
None
}
if
task
==
"redux"
else
{}
pipeline_init_kwargs
=
{
"text_encoder"
:
None
,
"text_encoder2"
:
None
}
if
task
==
"redux"
else
{}
pipeline
=
pipeline_cls
.
from_pretrained
(
model_id_16bit
,
torch_dtype
=
dtype
,
**
pipeline_init_kwargs
)
pipeline
=
pipeline_cls
.
from_pretrained
(
model_id_16bit
,
torch_dtype
=
dtype
,
**
pipeline_init_kwargs
)
pipeline
=
pipeline
.
to
(
"cuda"
)
gpu_properties
=
torch
.
cuda
.
get_device_properties
(
0
)
gpu_memory
=
gpu_properties
.
total_memory
/
(
1024
**
2
)
if
gpu_memory
>
36
*
1024
:
pipeline
=
pipeline
.
to
(
"cuda"
)
else
:
pipeline
.
enable_sequential_cpu_offload
()
if
len
(
lora_names
)
>
0
:
if
len
(
lora_names
)
>
0
:
for
i
,
(
lora_name
,
lora_strength
)
in
enumerate
(
zip
(
lora_names
,
lora_strengths
)):
for
i
,
(
lora_name
,
lora_strength
)
in
enumerate
(
zip
(
lora_names
,
lora_strengths
)):
...
@@ -269,4 +272,4 @@ def run_test(
...
@@ -269,4 +272,4 @@ def run_test(
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
lpips
=
compute_lpips
(
save_dir_16bit
,
save_dir_4bit
)
lpips
=
compute_lpips
(
save_dir_16bit
,
save_dir_4bit
)
print
(
f
"lpips:
{
lpips
}
"
)
print
(
f
"lpips:
{
lpips
}
"
)
assert
lpips
<
expected_lpips
*
1.
0
5
assert
lpips
<
expected_lpips
*
1.
2
5
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment