Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
3ecf0d25
Commit
3ecf0d25
authored
Apr 11, 2025
by
muyangli
Browse files
update
parent
e0ffd99d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
nunchaku/test.py
nunchaku/test.py
+5
-5
tests/flux/utils.py
tests/flux/utils.py
+3
-3
No files found.
nunchaku/test.py
View file @
3ecf0d25
...
...
@@ -2,17 +2,17 @@ import torch
from
diffusers
import
FluxPipeline
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
nunchaku.utils
import
get_precision
,
is_turing
if
__name__
==
"__main__"
:
capability
=
torch
.
cuda
.
get_device_capability
(
0
)
sm
=
f
"
{
capability
[
0
]
}{
capability
[
1
]
}
"
precision
=
"fp4"
if
sm
==
"120"
else
"int4"
precision
=
get_precision
()
torch_dtype
=
torch
.
float16
if
is_turing
()
else
torch
.
bfloat16
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-schnell"
,
offload
=
True
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-schnell"
,
torch_dtype
=
torch_dtype
,
offload
=
True
)
pipeline
=
FluxPipeline
.
from_pretrained
(
"black-forest-labs/FLUX.1-schnell"
,
transformer
=
transformer
,
torch_dtype
=
torch
.
bfloat16
"black-forest-labs/FLUX.1-schnell"
,
transformer
=
transformer
,
torch_dtype
=
torch
_dtype
)
pipeline
.
enable_sequential_cpu_offload
()
image
=
pipeline
(
...
...
tests/flux/utils.py
View file @
3ecf0d25
...
...
@@ -61,9 +61,7 @@ def run_pipeline(dataset, batch_size: int, task: str, pipeline: FluxPipeline, sa
assert
task
in
[
"t2i"
,
"fill"
]
processor
=
None
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
False
)
for
row
in
tqdm
(
dataloader
):
for
row
in
tqdm
(
dataset
.
iter
(
batch_size
=
batch_size
,
drop_last_batch
=
False
)):
filenames
=
row
[
"filename"
]
prompts
=
row
[
"prompt"
]
...
...
@@ -234,6 +232,8 @@ def run_test(
precision_str
+=
f
"-cache
{
cache_threshold
}
"
if
i2f_mode
is
not
None
:
precision_str
+=
f
"-i2f
{
i2f_mode
}
"
if
batch_size
>
1
:
precision_str
+=
f
"-bs
{
batch_size
}
"
save_dir_4bit
=
os
.
path
.
join
(
"test_results"
,
dtype_str
,
precision_str
,
model_name
,
folder_name
)
if
not
already_generate
(
save_dir_4bit
,
max_dataset_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment