Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
e0ffd99d
Commit
e0ffd99d
authored
Apr 11, 2025
by
muyangli
Browse files
add batch inference
parent
30ba84c5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
27 deletions
+48
-27
tests/flux/utils.py
tests/flux/utils.py
+48
-27
No files found.
tests/flux/utils.py
View file @
e0ffd99d
...
...
@@ -45,7 +45,7 @@ LORA_PATH_MAP = {
}
def
run_pipeline
(
dataset
,
task
:
str
,
pipeline
:
FluxPipeline
,
save_dir
:
str
,
forward_kwargs
:
dict
=
{}):
def
run_pipeline
(
dataset
,
batch_size
:
int
,
task
:
str
,
pipeline
:
FluxPipeline
,
save_dir
:
str
,
forward_kwargs
:
dict
=
{}):
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
pipeline
.
set_progress_bar_config
(
desc
=
"Sampling"
,
leave
=
False
,
dynamic_ncols
=
True
,
position
=
1
)
...
...
@@ -61,43 +61,61 @@ def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forw
assert
task
in
[
"t2i"
,
"fill"
]
processor
=
None
for
row
in
tqdm
(
dataset
):
filename
=
row
[
"filename"
]
prompt
=
row
[
"prompt"
]
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
False
)
for
row
in
tqdm
(
dataloader
):
filenames
=
row
[
"filename"
]
prompts
=
row
[
"prompt"
]
_forward_kwargs
=
{
k
:
v
for
k
,
v
in
forward_kwargs
.
items
()}
if
task
==
"canny"
:
assert
forward_kwargs
.
get
(
"height"
,
1024
)
==
1024
assert
forward_kwargs
.
get
(
"width"
,
1024
)
==
1024
control_image
=
load_image
(
row
[
"canny_image_path"
])
control_image
=
processor
(
control_image
,
low_threshold
=
50
,
high_threshold
=
200
,
detect_resolution
=
1024
,
image_resolution
=
1024
,
)
_forward_kwargs
[
"control_image"
]
=
control_image
control_images
=
[]
for
canny_image_path
in
row
[
"canny_image_path"
]:
control_image
=
load_image
(
canny_image_path
)
control_image
=
processor
(
control_image
,
low_threshold
=
50
,
high_threshold
=
200
,
detect_resolution
=
1024
,
image_resolution
=
1024
,
)
control_images
.
append
(
control_image
)
_forward_kwargs
[
"control_image"
]
=
control_images
elif
task
==
"depth"
:
control_image
=
load_image
(
row
[
"depth_image_path"
])
control_image
=
processor
(
control_image
)[
0
].
convert
(
"RGB"
)
_forward_kwargs
[
"control_image"
]
=
control_image
control_images
=
[]
for
depth_image_path
in
row
[
"depth_image_path"
]:
control_image
=
load_image
(
depth_image_path
)
control_image
=
processor
(
control_image
)[
0
].
convert
(
"RGB"
)
control_images
.
append
(
control_image
)
_forward_kwargs
[
"control_image"
]
=
control_images
elif
task
==
"fill"
:
image
=
load_image
(
row
[
"image_path"
])
mask_image
=
load_image
(
row
[
"mask_image_path"
])
_forward_kwargs
[
"image"
]
=
image
_forward_kwargs
[
"mask_image"
]
=
mask_image
images
,
mask_images
=
[],
[]
for
image_path
,
mask_image_path
in
zip
(
row
[
"image_path"
],
row
[
"mask_image_path"
]):
image
=
load_image
(
image_path
)
mask_image
=
load_image
(
mask_image_path
)
images
.
append
(
image
)
mask_images
.
append
(
mask_image
)
_forward_kwargs
[
"image"
]
=
images
_forward_kwargs
[
"mask_image"
]
=
mask_images
elif
task
==
"redux"
:
image
=
load_image
(
row
[
"image_path"
])
_forward_kwargs
.
update
(
processor
(
image
))
images
=
[]
for
image_path
in
row
[
"image_path"
]:
image
=
load_image
(
image_path
)
images
.
append
(
image
)
_forward_kwargs
.
update
(
processor
(
images
))
seed
=
hash_str_to_int
(
filename
)
seeds
=
[
hash_str_to_int
(
filename
)
for
filename
in
filenames
]
generators
=
[
torch
.
Generator
().
manual_seed
(
seed
)
for
seed
in
seeds
]
if
task
==
"redux"
:
image
=
pipeline
(
generator
=
torch
.
G
enerator
().
manual_seed
(
seed
)
,
**
_forward_kwargs
).
images
[
0
]
image
s
=
pipeline
(
generator
=
g
enerator
s
,
**
_forward_kwargs
).
images
else
:
image
=
pipeline
(
prompt
,
generator
=
torch
.
Generator
().
manual_seed
(
seed
),
**
_forward_kwargs
).
images
[
0
]
image
.
save
(
os
.
path
.
join
(
save_dir
,
f
"
{
filename
}
.png"
))
images
=
pipeline
(
prompts
,
generator
=
generators
,
**
_forward_kwargs
).
images
for
i
,
image
in
enumerate
(
images
):
filename
=
filenames
[
i
]
image
.
save
(
os
.
path
.
join
(
save_dir
,
f
"
{
filename
}
.png"
))
torch
.
cuda
.
empty_cache
()
...
...
@@ -105,6 +123,7 @@ def run_test(
precision
:
str
=
"int4"
,
model_name
:
str
=
"flux.1-schnell"
,
dataset_name
:
str
=
"MJHQ"
,
batch_size
:
int
=
1
,
task
:
str
=
"t2i"
,
dtype
:
str
|
torch
.
dtype
=
torch
.
bfloat16
,
# the full precision dtype
height
:
int
=
1024
,
...
...
@@ -185,6 +204,7 @@ def run_test(
pipeline
.
set_adapters
([
f
"lora_
{
i
}
"
for
i
in
range
(
len
(
lora_names
))],
lora_strengths
)
run_pipeline
(
batch_size
=
batch_size
,
dataset
=
dataset
,
task
=
task
,
pipeline
=
pipeline
,
...
...
@@ -255,6 +275,7 @@ def run_test(
else
:
pipeline
=
pipeline
.
to
(
"cuda"
)
run_pipeline
(
batch_size
=
batch_size
,
dataset
=
dataset
,
task
=
task
,
pipeline
=
pipeline
,
...
...
@@ -272,4 +293,4 @@ def run_test(
torch
.
cuda
.
empty_cache
()
lpips
=
compute_lpips
(
save_dir_16bit
,
save_dir_4bit
)
print
(
f
"lpips:
{
lpips
}
"
)
assert
lpips
<
expected_lpips
*
1.
25
assert
lpips
<
expected_lpips
*
1.
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment