Commit 37a27712 (unverified)
Authored May 01, 2025 by Muyang Li; committed by GitHub on May 01, 2025

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter

Parents: c1d6fc84, 760ab022
Showing 12 changed files with 312 additions and 30 deletions.
tests/flux/test_flux_dev_loras.py        +1   -0
tests/flux/test_flux_dev_pulid.py        +52  -0
tests/flux/test_flux_double_fb_cache.py  +43  -0
tests/flux/test_flux_examples.py         +4   -0
tests/flux/test_flux_schnell.py          +4   -3
tests/flux/test_flux_teacache.py         +144 -0
tests/flux/test_flux_tools.py            +1   -0
tests/flux/test_multiple_batch.py        +3   -2
tests/flux/test_shuttle_jaguar.py        +2   -1
tests/flux/test_turing.py                +1   -0
tests/flux/utils.py                      +52  -23
tests/requirements.txt                   +5   -1
tests/flux/test_flux_dev_loras.py

import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test
...
tests/flux/test_flux_dev_pulid.py (new file, mode 100644)

from types import MethodType

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.pulid.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
    precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
    pipeline = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

    id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")

    image = pipeline(
        "A woman holding a sign that says hello world",
        id_image=id_image,
        id_weight=1,
        num_inference_steps=12,
        guidance_scale=3.5,
    ).images[0]

    # Embed the reference face and the generated face, then compare them.
    id_image = id_image.convert("RGB")
    id_image_numpy = np.array(id_image)
    id_image = resize_numpy_image_long(id_image_numpy, 1024)
    id_embeddings, _ = pipeline.pulid_model.get_id_embedding(id_image)

    output_image = image.convert("RGB")
    output_image_numpy = np.array(output_image)
    output_image = resize_numpy_image_long(output_image_numpy, 1024)
    output_id_embeddings, _ = pipeline.pulid_model.get_id_embedding(output_image)

    cosine_similarities = (
        F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
    )
    print(cosine_similarities)
    assert cosine_similarities > 0.93
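For anyone picking up the new PuLID support outside the test harness, the moving parts above reduce to three steps: load the quantized transformer, wrap it in PuLIDFluxPipeline, and rebind the transformer's forward to pulid_forward. A minimal sketch distilled from the test; the reference-image path and the output filename are placeholders:

from types import MethodType

import torch
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision

# Quantized FLUX.1-dev transformer at the precision your GPU supports.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{get_precision()}-flux.1-dev"
)
pipeline = PuLIDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
# Rebind the forward pass so the transformer consumes the identity embedding.
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

id_image = load_image("face.png")  # placeholder: any reference portrait
image = pipeline(
    "A woman holding a sign that says hello world",
    id_image=id_image,
    id_weight=1,  # how strongly the identity is enforced
    num_inference_steps=12,
    guidance_scale=3.5,
).images[0]
image.save("pulid_output.png")  # placeholder output path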
tests/flux/test_flux_double_fb_cache.py (new file, mode 100644)

import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
    [
        (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.165),
        (True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.161),
    ],
)
def test_flux_dev_double_fb_cache(
    use_double_fb_cache: bool,
    residual_diff_threshold_multi: float,
    residual_diff_threshold_single: float,
    height: int,
    width: int,
    num_inference_steps: int,
    lora_name: str,
    lora_strength: float,
    expected_lpips: float,
):
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ" if lora_name is None else lora_name,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        use_qencoder=False,
        cpu_offload=False,
        lora_names=lora_name,
        lora_strengths=lora_strength,
        use_double_fb_cache=use_double_fb_cache,
        residual_diff_threshold_multi=residual_diff_threshold_multi,
        residual_diff_threshold_single=residual_diff_threshold_single,
        expected_lpips=expected_lpips,
    )
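run_test switches Double FBCache on through apply_cache_on_pipe, which this merge adds to tests/flux/utils.py (see that diff below). A minimal sketch of enabling it on a plain FluxPipeline, using the thresholds parametrized above; the prompt and output path are placeholders, and the multi/single thresholds presumably govern the double-stream and single-stream FLUX blocks respectively:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{get_precision()}-flux.1-dev"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Double FBCache: separate residual-difference thresholds for the two block
# types; larger thresholds reuse cached block outputs more aggressively.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

image = pipeline(
    "A cat holding a sign that says hello world",  # placeholder prompt
    num_inference_steps=30,
    guidance_scale=3.5,
).images[0]
image.save("dfb_cache_output.png")  # placeholder output path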
tests/flux/test_flux_examples.py

import gc
import os
import subprocess

import pytest
import torch

EXAMPLES_DIR = "./examples"
...
@@ -10,6 +12,8 @@ example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and

@pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name):
    gc.collect()
    torch.cuda.empty_cache()
    script_path = os.path.join(EXAMPLES_DIR, script_name)
    result = subprocess.run(["python", script_path], capture_output=True, text=True)
    print(f"Running {script_path} -> Return code: {result.returncode}")
...
tests/flux/test_flux_schnell.py

import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test
...
@@ -8,13 +9,13 @@ from .utils import run_test

@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips",
    [
-        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
-        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
    ],
)
-def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
+def test_flux_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
    run_test(
        precision=get_precision(),
        height=height,
...
tests/flux/test_flux_teacache.py (new file, mode 100644)

import gc
import os

import pytest
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision, is_turing

from .utils import already_generate, compute_lpips, offload_pipeline


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,num_inference_steps,prompt,name,seed,threshold,expected_lpips",
    [
        (1024, 1024, 30, "A cat holding a sign that says hello world", "cat", 0, 0.6, 0.363 if get_precision() == "int4" else 0.363),
        (512, 2048, 25, "The brown fox jumps over the lazy dog", "fox", 1234, 0.7, 0.349 if get_precision() == "int4" else 0.349),
        (1024, 768, 50, "A scene from the Titanic movie featuring the Muppets", "muppets", 42, 0.3, 0.360 if get_precision() == "int4" else 0.495),
        (1024, 768, 50, "A crystal ball showing a waterfall", "waterfall", 23, 0.6, 0.226 if get_precision() == "int4" else 0.226),
    ],
)
def test_flux_teacache(
    height: int,
    width: int,
    num_inference_steps: int,
    prompt: str,
    name: str,
    seed: int,
    threshold: float,
    expected_lpips: float,
):
    gc.collect()
    torch.cuda.empty_cache()
    device = torch.device("cuda")
    precision = get_precision()
    ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
    results_dir_16_bit = os.path.join(ref_root, "bf16", "flux.1-dev", "teacache", name)
    results_dir_4_bit = os.path.join("test_results", precision, "flux.1-dev", "teacache", name)

    os.makedirs(results_dir_16_bit, exist_ok=True)
    os.makedirs(results_dir_4_bit, exist_ok=True)

    # First, generate results with the 16-bit model
    if not already_generate(results_dir_16_bit, 1):
        pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
        # Possibly offload the model to CPU when GPU memory is scarce
        pipeline = offload_pipeline(pipeline)
        result = pipeline(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=torch.Generator(device=device).manual_seed(seed),
        ).images[0]
        result.save(os.path.join(results_dir_16_bit, f"{name}_{seed}.png"))

        # Clean up the 16-bit model
        del pipeline.transformer
        del pipeline.text_encoder
        del pipeline.text_encoder_2
        del pipeline.vae
        del pipeline
        del result
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        free, total = torch.cuda.mem_get_info()  # bytes
        print(f"After 16-bit generation: Free: {free / 1024 ** 2:.0f} MB / Total: {total / 1024 ** 2:.0f} MB")

    # Then, generate results with the 4-bit model
    if not already_generate(results_dir_4_bit, 1):
        transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")
        with torch.inference_mode():
            with TeaCache(
                model=pipeline.transformer,
                num_steps=num_inference_steps,
                rel_l1_thresh=threshold,
                enabled=True,
            ):
                result = pipeline(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width,
                    generator=torch.Generator(device=device).manual_seed(seed),
                ).images[0]
        result.save(os.path.join(results_dir_4_bit, f"{name}_{seed}.png"))

        # Clean up the 4-bit model
        del pipeline
        del transformer
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        free, total = torch.cuda.mem_get_info()  # bytes
        print(f"After 4-bit generation: Free: {free / 1024 ** 2:.0f} MB / Total: {total / 1024 ** 2:.0f} MB")

    lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.1
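TeaCache is used above as a context manager wrapped around the sampling call, with num_steps matching the pipeline's num_inference_steps and rel_l1_thresh controlling how aggressively steps are reused (the parametrization suggests larger values cache more at some quality cost). A minimal sketch distilled from the 4-bit branch of the test; the prompt and output path are placeholders:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

num_inference_steps = 30

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{get_precision()}-flux.1-dev"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# TeaCache wraps the sampling loop; num_steps must match the pipeline call.
with torch.inference_mode():
    with TeaCache(
        model=pipeline.transformer,
        num_steps=num_inference_steps,
        rel_l1_thresh=0.6,  # threshold value taken from the test's "cat" case
        enabled=True,
    ):
        image = pipeline(
            "A cat holding a sign that says hello world",  # placeholder prompt
            num_inference_steps=num_inference_steps,
        ).images[0]
image.save("teacache_output.png")  # placeholder output path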
tests/flux/test_flux_tools.py
...
@@ -2,6 +2,7 @@ import pytest
import torch

from nunchaku.utils import get_precision, is_turing

from .utils import run_test
...
tests/flux/test_multiple_batch.py
...
@@ -2,6 +2,7 @@
import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test
...
@@ -9,11 +10,11 @@ from .utils import run_test

@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
    [
-        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
    ],
)
-def test_int4_schnell(
+def test_flux_schnell(
    height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float, batch_size: int
):
    run_test(
...
tests/flux/test_shuttle_jaguar.py

import pytest

-from .utils import run_test
+from nunchaku.utils import get_precision, is_turing
+
+from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
...
tests/flux/test_turing.py

import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test
...
tests/flux/utils.py
...
@@ -12,7 +12,9 @@ from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux.compose import compose_lora

from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
...
@@ -141,6 +143,9 @@ def run_test(
    attention_impl: str = "flashattn2",  # "flashattn2" or "nunchaku-fp16"
    cpu_offload: bool = False,
    cache_threshold: float = 0,
+    use_double_fb_cache: bool = False,
+    residual_diff_threshold_multi: float = 0,
+    residual_diff_threshold_single: float = 0,
    lora_names: str | list[str] | None = None,
    lora_strengths: float | list[float] = 1.0,
    max_dataset_size: int = 4,
...
@@ -196,8 +201,6 @@ def run_test(
    if not already_generate(save_dir_16bit, max_dataset_size):
        pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
        pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
-        gpu_properties = torch.cuda.get_device_properties(0)
-        gpu_memory = gpu_properties.total_memory / (1024 ** 2)
        if len(lora_names) > 0:
            for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
...
@@ -207,27 +210,7 @@ def run_test(
            )
        pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
-        if gpu_memory > 36 * 1024:
-            pipeline = pipeline.to("cuda")
-        elif gpu_memory < 26 * 1024:
-            pipeline.transformer.enable_group_offload(
-                onload_device=torch.device("cuda"),
-                offload_device=torch.device("cpu"),
-                offload_type="leaf_level",
-                use_stream=True,
-            )
-            if pipeline.text_encoder is not None:
-                pipeline.text_encoder.to("cuda")
-            if pipeline.text_encoder_2 is not None:
-                apply_group_offloading(
-                    pipeline.text_encoder_2,
-                    onload_device=torch.device("cuda"),
-                    offload_type="block_level",
-                    num_blocks_per_group=2,
-                )
-            pipeline.vae.to("cuda")
-        else:
-            pipeline.enable_model_cpu_offload()
+        pipeline = offload_pipeline(pipeline)
        run_pipeline(
            batch_size=batch_size,
...
@@ -259,6 +242,12 @@ def run_test(
        precision_str += "-co"
    if cache_threshold > 0:
        precision_str += f"-cache{cache_threshold}"
+    if use_double_fb_cache:
+        precision_str += "-dfb"
+    if residual_diff_threshold_multi > 0:
+        precision_str += f"-rdm{residual_diff_threshold_multi}"
+    if residual_diff_threshold_single > 0:
+        precision_str += f"-rds{residual_diff_threshold_single}"
    if i2f_mode is not None:
        precision_str += f"-i2f{i2f_mode}"
    if batch_size > 1:
...
@@ -303,6 +292,15 @@ def run_test(
        pipeline.enable_sequential_cpu_offload()
    else:
        pipeline = pipeline.to("cuda")
+    if use_double_fb_cache:
+        apply_cache_on_pipe(
+            pipeline,
+            use_double_fb_cache=use_double_fb_cache,
+            residual_diff_threshold_multi=residual_diff_threshold_multi,
+            residual_diff_threshold_single=residual_diff_threshold_single,
+        )
    run_pipeline(
        batch_size=batch_size,
        dataset=dataset,
...
@@ -324,3 +322,34 @@ def run_test(
    lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.1


+def offload_pipeline(pipeline: FluxPipeline) -> FluxPipeline:
+    gpu_properties = torch.cuda.get_device_properties(0)
+    gpu_memory = gpu_properties.total_memory / (1024 ** 2)
+    device = torch.device("cuda")
+    cpu = torch.device("cpu")
+    if gpu_memory > 36 * 1024:
+        # Plenty of VRAM: keep the whole pipeline on the GPU.
+        pipeline = pipeline.to(device)
+    elif gpu_memory < 26 * 1024:
+        # Tight VRAM: stream the transformer between CPU and GPU leaf by leaf.
+        pipeline.transformer.enable_group_offload(
+            onload_device=device,
+            offload_device=cpu,
+            offload_type="leaf_level",
+            use_stream=True,
+        )
+        if pipeline.text_encoder is not None:
+            pipeline.text_encoder.to(device)
+        if pipeline.text_encoder_2 is not None:
+            apply_group_offloading(
+                pipeline.text_encoder_2,
+                onload_device=device,
+                offload_type="block_level",
+                num_blocks_per_group=2,
+            )
+        pipeline.vae.to(device)
+    else:
+        pipeline.enable_model_cpu_offload()
+    return pipeline
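The new offload_pipeline helper centralizes the memory-tiering policy that the TeaCache test and run_test now share: with more than roughly 36 GiB of VRAM the whole pipeline moves to the GPU; with less than roughly 26 GiB the transformer is group-offloaded at leaf level using CUDA streams while the text encoders and VAE are placed separately (text_encoder_2 via diffusers' apply_group_offloading at block level); anything in between falls back to diffusers' standard enable_model_cpu_offload.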
tests/requirements.txt
...
@@ -5,4 +5,8 @@ torchmetrics
mediapipe
controlnet_aux
peft
-git+https://github.com/asomoza/image_gen_aux.git
\ No newline at end of file
+git+https://github.com/asomoza/image_gen_aux.git
+insightface
+opencv-python
+facexlib
+onnxruntime
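The four new test dependencies appear to support the PuLID tests: insightface and facexlib supply the face-analysis models behind the ID embeddings, onnxruntime executes them, and opencv-python covers the image preprocessing.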