Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
2ede5f01
Commit
2ede5f01
authored
Apr 03, 2025
by
Muyang Li
Committed by
Zhekai Zhang
Apr 04, 2025
Browse files
Clean some codes and refract the tests
parent
83b7542d
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
1 deletion
+6
-1
tests/requirements.txt
tests/requirements.txt
+1
-0
tests/sana/test_t2i.py
tests/sana/test_t2i.py
+4
-0
tests/utils.py
tests/utils.py
+1
-1
No files found.
tests/requirements.txt
View file @
2ede5f01
# additional requirements for testing
pytest
pytest
datasets
datasets
torchmetrics
torchmetrics
...
...
tests/sana/test_t2i.py
View file @
2ede5f01
import
pytest
import
torch
import
torch
from
diffusers
import
SanaPAGPipeline
,
SanaPipeline
from
diffusers
import
SanaPAGPipeline
,
SanaPipeline
from
nunchaku
import
NunchakuSanaTransformer2DModel
from
nunchaku
import
NunchakuSanaTransformer2DModel
from
nunchaku.utils
import
get_precision
,
is_turing
@
pytest
.
mark
.
skipif
(
is_turing
()
or
get_precision
()
==
"fp4"
,
reason
=
"Skip tests due to Turing GPUs"
)
def
test_sana
():
def
test_sana
():
transformer
=
NunchakuSanaTransformer2DModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-sana-1600m"
)
transformer
=
NunchakuSanaTransformer2DModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-sana-1600m"
)
pipe
=
SanaPipeline
.
from_pretrained
(
pipe
=
SanaPipeline
.
from_pretrained
(
...
@@ -28,6 +31,7 @@ def test_sana():
...
@@ -28,6 +31,7 @@ def test_sana():
image
.
save
(
"sana_1600m.png"
)
image
.
save
(
"sana_1600m.png"
)
@
pytest
.
mark
.
skipif
(
is_turing
()
or
get_precision
()
==
"fp4"
,
reason
=
"Skip tests due to Turing GPUs"
)
def
test_sana_pag
():
def
test_sana_pag
():
transformer
=
NunchakuSanaTransformer2DModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-sana-1600m"
,
pag_layers
=
8
)
transformer
=
NunchakuSanaTransformer2DModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-sana-1600m"
,
pag_layers
=
8
)
pipe
=
SanaPAGPipeline
.
from_pretrained
(
pipe
=
SanaPAGPipeline
.
from_pretrained
(
...
...
tests/utils.py
View file @
2ede5f01
...
@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset):
...
@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset):
def
compute_lpips
(
def
compute_lpips
(
ref_dirpath
:
str
,
gen_dirpath
:
str
,
batch_size
:
int
=
6
4
,
num_workers
:
int
=
8
,
device
:
str
|
torch
.
device
=
"cuda"
ref_dirpath
:
str
,
gen_dirpath
:
str
,
batch_size
:
int
=
4
,
num_workers
:
int
=
8
,
device
:
str
|
torch
.
device
=
"cuda"
)
->
float
:
)
->
float
:
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
metric
=
LearnedPerceptualImagePatchSimilarity
(
normalize
=
True
).
to
(
device
)
metric
=
LearnedPerceptualImagePatchSimilarity
(
normalize
=
True
).
to
(
device
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment