fengzch-das / nunchaku / Commits

Commit 140b21e5 (unverified)
Authored May 18, 2025 by Muyang Li; committed by GitHub on May 18, 2025.
Parents: 2eedc2cb, f828be33

    release: v0.3.0dev1 pre-release
Showing 11 changed files with 111 additions and 28 deletions:

    app/flux.1/sketch/run_gradio.py                   +5   -1
    examples/flux.1-dev-pulid.py                      +1   -1
    nunchaku/caching/utils.py                         +6   -7
    nunchaku/lora/flux/compose.py                     +1   -1
    nunchaku/lora/flux/nunchaku_converter.py          +3   -0
    nunchaku/models/transformers/transformer_flux.py  +11  -0
    tests/flux/test_flux_schnell.py                   +1   -1
    tests/flux/test_flux_teacache.py                  +1   -1
    tests/flux/test_lora_reset.py                     +47  -0
    tests/flux/test_multiple_batch.py                 +1   -1
    tests/utils.py                                    +34  -15
app/flux.1/sketch/run_gradio.py

@@ -65,7 +65,11 @@ def save_image(img):
 def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
     print(f"Prompt: {prompt}")
-    image_numpy = np.array(image["composite"].convert("RGB"))
+    if image["composite"] is None:
+        image_numpy = np.array(blank_image.convert("RGB"))
+    else:
+        image_numpy = np.array(image["composite"].convert("RGB"))
     if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
         return blank_image, "Please input the prompt or draw something."
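The new branch guards against an empty canvas: the Gradio image editor hands run() a dict whose "composite" entry can be None when nothing has been drawn, and calling .convert() on None would raise. A minimal standalone sketch of the pattern; blank_image and the editor-dict shape are assumptions taken from the surrounding script, not verbatim app code:

    import numpy as np
    from PIL import Image

    # Hypothetical stand-in for the app's module-level blank canvas.
    blank_image = Image.new("RGB", (1024, 1024), color=(255, 255, 255))

    def editor_to_numpy(editor_value: dict) -> np.ndarray:
        # Fall back to the blank canvas when the user has not drawn anything.
        if editor_value["composite"] is None:
            return np.array(blank_image.convert("RGB"))
        return np.array(editor_value["composite"].convert("RGB"))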
examples/flux.1-dev-pulid.py

@@ -22,7 +22,7 @@ pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)
 id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
 image = pipeline(
-    "A woman holding a sign that says 'SVDQuant is fast!",
+    "A woman holding a sign that says 'SVDQuant is fast!'",
     id_image=id_image,
     id_weight=1,
     num_inference_steps=12,
nunchaku/caching/utils.py

@@ -390,19 +390,18 @@ class FluxCachedTransformerBlocks(nn.Module):
         original_dtype = hidden_states.dtype
         original_device = hidden_states.device
-        hidden_states = hidden_states.to(self.dtype).to(self.device)
-        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
-        temb = temb.to(self.dtype).to(self.device)
-        image_rotary_emb = image_rotary_emb.to(self.device)
+        hidden_states = hidden_states.to(self.dtype).to(original_device)
+        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(original_device)
+        temb = temb.to(self.dtype).to(original_device)
+        image_rotary_emb = image_rotary_emb.to(original_device)
         if controlnet_block_samples is not None:
             controlnet_block_samples = (
-                torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
+                torch.stack(controlnet_block_samples).to(original_device) if len(controlnet_block_samples) > 0 else None
             )
         if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
             controlnet_single_block_samples = (
-                torch.stack(controlnet_single_block_samples).to(self.device)
+                torch.stack(controlnet_single_block_samples).to(original_device)
                 if len(controlnet_single_block_samples) > 0
                 else None
             )
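Every cast in this hunk now returns tensors to the device they arrived on (original_device) rather than forcing them onto self.device; a plausible motivation, though not stated in the commit, is keeping cached blocks correct when the pipeline offloads activations to the CPU. A minimal sketch of the pattern, with illustrative names (module_dtype stands in for self.dtype):

    import torch

    def cast_preserving_device(t: torch.Tensor, module_dtype: torch.dtype) -> torch.Tensor:
        original_device = t.device
        # Change dtype, but leave the tensor on the device it already occupies
        # instead of moving it to the module's own device.
        return t.to(module_dtype).to(original_device)

    x = torch.randn(4, 4, device="cpu", dtype=torch.float32)
    y = cast_preserving_device(x, torch.bfloat16)
    assert y.device == x.device and y.dtype == torch.bfloat16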
nunchaku/lora/flux/compose.py

@@ -136,4 +136,4 @@ if __name__ == "__main__":
     parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
     args = parser.parse_args()
     assert len(args.input_paths) == len(args.strengths)
-    composed = compose_lora(list(zip(args.input_paths, args.strengths)))
+    compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
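compose_lora now receives the output path directly, so saving happens inside the function instead of in the CLI wrapper. A hedged usage sketch with hypothetical file names; the list-of-(path, strength)-pairs form follows from list(zip(args.input_paths, args.strengths)), and whether the function also returns the composed state dict is not shown in this hunk:

    from nunchaku.lora.flux.compose import compose_lora

    # Compose two LoRAs at the given strengths and write the result to disk.
    compose_lora(
        [("lora_a.safetensors", 0.8), ("lora_b.safetensors", 0.6)],
        "composed_lora.safetensors",
    )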
nunchaku/lora/flux/nunchaku_converter.py

@@ -117,6 +117,9 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
             if orig_lora[0] is None or orig_lora[1] is None:
                 assert orig_lora[0] is None and orig_lora[1] is None
                 orig_lora = None
+            elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0:
+                assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0
+                orig_lora = None
             else:
                 assert orig_lora[0] is not None and orig_lora[1] is not None
                 orig_lora = (
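The new elif treats a LoRA pair whose tensors exist but are empty (numel() == 0) exactly like a missing pair, collapsing both cases to None before conversion. A standalone sketch of that normalization rule, with hypothetical inputs:

    import torch

    def normalize_lora_pair(pair):
        down, up = pair
        if down is None or up is None:
            # Either both halves are absent, or the pair is malformed.
            assert down is None and up is None
            return None
        if down.numel() == 0 or up.numel() == 0:
            # Empty tensors are treated the same as absent ones.
            assert down.numel() == 0 and up.numel() == 0
            return None
        return pair

    assert normalize_lora_pair((None, None)) is None
    assert normalize_lora_pair((torch.zeros(0, 8), torch.zeros(4, 0))) is None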
nunchaku/models/transformers/transformer_flux.py

@@ -333,6 +333,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
         elif "qweight" in k:
             # only the shape information of this tensor is needed
             new_quantized_part_sd[k] = v.to("meta")
+            # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+            for t in ["lora_up", "lora_down"]:
+                new_k = k.replace(".qweight", f".{t}")
+                if new_k not in quantized_part_sd:
+                    oc, ic = v.shape
+                    ic = ic * 2  # v is packed into INT8, so we need to double the size
+                    new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros(
+                        (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
+                    )
         elif "lora" in k:
             new_quantized_part_sd[k] = v
     transformer._quantized_part_sd = new_quantized_part_sd
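When a layer has a quantized weight but no low-rank branch, the loader now fabricates zero-rank lora_up/lora_down placeholders whose shapes stay consistent with the unpacked weight. A shape-only sketch; the factor of 2 follows the hunk's comment that two quantized values are packed into each stored element, and the concrete sizes are hypothetical:

    import torch

    qweight = torch.zeros((64, 128), dtype=torch.int8)  # hypothetical packed weight
    oc, ic = qweight.shape
    ic = ic * 2  # unpacked input-channel count: two values per stored element

    # Zero-rank placeholders: (oc, 0) @ (0, ic) multiplies out to a zeros
    # matrix of shape (oc, ic), so an absent branch contributes nothing.
    lora_down = torch.zeros((0, ic), dtype=torch.bfloat16)
    lora_up = torch.zeros((oc, 0), dtype=torch.bfloat16)
    assert (lora_up.float() @ lora_down.float()).shape == (oc, ic)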
tests/flux/test_flux_schnell.py

@@ -11,7 +11,7 @@ from .utils import run_test
     [
         (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
         (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (1920, 1080, "nunchaku-fp16", False, 0.190 if get_precision() == "int4" else 0.138),
         (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
tests/flux/test_flux_teacache.py

@@ -54,7 +54,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
             "waterfall",
             23,
             0.6,
-            0.226 if get_precision() == "int4" else 0.226,
+            0.253 if get_precision() == "int4" else 0.226,
         ),
     ],
 )
tests/flux/test_lora_reset.py (new file, mode 100644)

import os

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

from ..utils import compute_lpips


def test_lora_reset():
    precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/svdq-{precision}-flux.1-dev", offload=True
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    )
    pipeline.enable_sequential_cpu_offload()

    save_dir = os.path.join("test_results", "bf16", "flux", "lora_reset")
    os.makedirs(save_dir, exist_ok=True)

    image = pipeline(
        "cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows",  # noqa: E501
        num_inference_steps=8,
        guidance_scale=3.5,
        generator=torch.Generator().manual_seed(23),
    ).images[0]
    image.save(os.path.join(save_dir, "before.png"))

    transformer.update_lora_params("alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors")
    transformer.set_lora_strength(50)
    transformer.reset_lora()

    image = pipeline(
        "cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows",  # noqa: E501
        num_inference_steps=8,
        guidance_scale=3.5,
        generator=torch.Generator().manual_seed(23),
    ).images[0]
    image.save(os.path.join(save_dir, "after.png"))

    lpips = compute_lpips(os.path.join(save_dir, "before.png"), os.path.join(save_dir, "after.png"))
    print(f"LPIPS: {lpips}")
    assert lpips < 0.158 * 1.1
tests/flux/test_multiple_batch.py

@@ -11,7 +11,7 @@ from .utils import run_test
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
         (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
-        (1920, 1080, "flashattn2", True, 0.160 if get_precision() == "int4" else 0.123, 4),
+        (1920, 1080, "flashattn2", True, 0.177 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_flux_schnell(
tests/utils.py

@@ -28,15 +28,32 @@ def already_generate(save_dir: str, num_images) -> bool:
 class MultiImageDataset(data.Dataset):
-    def __init__(self, gen_dirpath: str, ref_dirpath: str | datasets.Dataset):
+    def __init__(self, gen_dirpath_or_image_path: str, ref_dirpath_or_image_path: str | datasets.Dataset):
         super(data.Dataset, self).__init__()
-        self.gen_names = sorted(
-            [name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
-        )
-        self.ref_names = sorted(
-            [name for name in os.listdir(ref_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
-        )
-        self.gen_dirpath, self.ref_dirpath = gen_dirpath, ref_dirpath
+        if os.path.isdir(gen_dirpath_or_image_path):
+            self.gen_names = sorted(
+                [
+                    name
+                    for name in os.listdir(gen_dirpath_or_image_path)
+                    if name.endswith(".png") or name.endswith(".jpg")
+                ]
+            )
+            self.gen_dirpath = gen_dirpath_or_image_path
+        else:
+            self.gen_names = [os.path.basename(gen_dirpath_or_image_path)]
+            self.gen_dirpath = os.path.dirname(gen_dirpath_or_image_path)
+        if os.path.isdir(ref_dirpath_or_image_path):
+            self.ref_names = sorted(
+                [
+                    name
+                    for name in os.listdir(ref_dirpath_or_image_path)
+                    if name.endswith(".png") or name.endswith(".jpg")
+                ]
+            )
+            self.ref_dirpath = ref_dirpath_or_image_path
+        else:
+            self.ref_names = [os.path.basename(ref_dirpath_or_image_path)]
+            self.ref_dirpath = os.path.dirname(ref_dirpath_or_image_path)
         assert len(self.ref_names) == len(self.gen_names)
         self.transform = torchvision.transforms.ToTensor()

@@ -45,10 +62,8 @@ class MultiImageDataset(data.Dataset):
         return len(self.ref_names)

     def __getitem__(self, idx: int):
-        ref_image = Image.open(os.path.join(self.ref_dirpath, self.ref_names[idx])).convert("RGB")
-        gen_image = Image.open(os.path.join(self.gen_dirpath, self.gen_names[idx])).convert("RGB")
+        name = self.ref_names[idx]
+        assert name == self.gen_names[idx]
+        ref_image = Image.open(os.path.join(self.ref_dirpath, name)).convert("RGB")
+        gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
         gen_size = gen_image.size
         ref_size = ref_image.size
         if ref_size != gen_size:

@@ -59,16 +74,20 @@ class MultiImageDataset(data.Dataset):
 def compute_lpips(
-    ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 0, device: str | torch.device = "cuda"
+    ref_dirpath_or_image_path: str,
+    gen_dirpath_or_image_path: str,
+    batch_size: int = 4,
+    num_workers: int = 0,
+    device: str | torch.device = "cuda",
 ) -> float:
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
-    dataset = MultiImageDataset(gen_dirpath, ref_dirpath)
+    dataset = MultiImageDataset(gen_dirpath_or_image_path, ref_dirpath_or_image_path)
     dataloader = data.DataLoader(
         dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
     )
     with torch.no_grad():
-        desc = (os.path.basename(gen_dirpath)) + " LPIPS"
+        desc = (os.path.basename(gen_dirpath_or_image_path)) + " LPIPS"
         for i, batch in enumerate(tqdm(dataloader, desc=desc)):
             batch = [tensor.to(device) for tensor in batch]
             metric.update(batch[0], batch[1])
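The renamed parameters describe the new behavior: each argument to compute_lpips (and to MultiImageDataset) may now be either a directory of .png/.jpg images or a single image file, which is what the new test_lora_reset.py relies on when it compares before.png against after.png. A usage sketch with hypothetical paths:

    # Illustrative absolute import; the tests themselves use a relative import.
    from tests.utils import compute_lpips

    # Directory vs. directory, as before:
    score = compute_lpips("results/ref/", "results/gen/")

    # Single file vs. single file, newly supported:
    score = compute_lpips("results/ref/a.png", "results/gen/a.png")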