Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
e0fadc93
Commit
e0fadc93
authored
Feb 23, 2025
by
muyangli
Browse files
[minor] fix some corner cases in lora conversion
parent
d7896cb4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
10 deletions
+78
-10
nunchaku/__version__.py
nunchaku/__version__.py
+1
-1
nunchaku/lora/flux/comfyui_converter.py
nunchaku/lora/flux/comfyui_converter.py
+32
-2
nunchaku/lora/flux/convert.py
nunchaku/lora/flux/convert.py
+2
-2
nunchaku/lora/flux/diffusers_converter.py
nunchaku/lora/flux/diffusers_converter.py
+43
-5
No files found.
nunchaku/__version__.py
View file @
e0fadc93
__version__
=
"0.1.
0
"
__version__
=
"0.1.
1
"
nunchaku/lora/flux/comfyui_converter.py
View file @
e0fadc93
# convert the comfyui lora to diffusers format
import
argparse
import
os
import
torch
...
...
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
def
comfyui2diffusers
(
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
,
min_rank
:
int
|
None
=
None
)
->
dict
[
str
,
torch
.
Tensor
]:
if
isinstance
(
input_lora
,
str
):
tensors
=
load_state_dict_in_safetensors
(
input_lora
,
device
=
"cpu"
)
...
...
@@ -16,9 +17,10 @@ def comfyui2diffusers(
tensors
=
input_lora
new_tensors
=
{}
max_alpha
=
0
for
k
,
v
in
tensors
.
items
():
if
"alpha"
in
k
:
max_alpha
=
max
(
max_alpha
,
v
.
max
().
item
())
continue
new_k
=
k
.
replace
(
"lora_down"
,
"lora_A"
).
replace
(
"lora_up"
,
"lora_B"
)
if
"lora_unet_double_blocks_"
in
k
:
...
...
@@ -72,8 +74,36 @@ def comfyui2diffusers(
new_k
=
new_k
.
replace
(
"_modulation_lin"
,
".norm.linear"
)
new_tensors
[
new_k
]
=
v
if
min_rank
is
not
None
:
for
k
in
new_tensors
.
keys
():
v
=
new_tensors
[
k
]
if
"lora_A"
in
k
:
rank
=
v
.
shape
[
0
]
if
rank
<
min_rank
:
new_v
=
torch
.
zeros
(
min_rank
,
v
.
shape
[
1
],
dtype
=
v
.
dtype
,
device
=
v
.
device
)
new_v
[:
rank
]
=
v
new_tensors
[
k
]
=
new_v
else
:
assert
"lora_B"
in
k
rank
=
v
.
shape
[
1
]
if
rank
<
min_rank
:
new_v
=
torch
.
zeros
(
v
.
shape
[
0
],
min_rank
,
dtype
=
v
.
dtype
,
device
=
v
.
device
)
new_v
[:,
:
rank
]
=
v
new_tensors
[
k
]
=
new_v
if
output_path
is
not
None
:
output_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
output_path
))
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
save_file
(
new_tensors
,
output_path
)
return
new_tensors
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-i"
,
"--input-path"
,
type
=
str
,
required
=
True
,
help
=
"path to the comfyui lora safetensor file"
)
parser
.
add_argument
(
"-o"
,
"--output-path"
,
type
=
str
,
required
=
True
,
help
=
"path to the output diffusers safetensor file"
)
parser
.
add_argument
(
"--min-rank"
,
type
=
int
,
default
=
None
,
help
=
"minimum rank for the LoRA weights"
)
args
=
parser
.
parse_args
()
comfyui2diffusers
(
args
.
input_path
,
args
.
output_path
,
min_rank
=
args
.
min_rank
)
nunchaku/lora/flux/convert.py
View file @
e0fadc93
...
...
@@ -37,8 +37,8 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
if
not
args
.
output_root
:
# output to the parent directory of the
quantized model
safetensor file
args
.
output_root
=
os
.
path
.
dirname
(
args
.
quant
_path
)
# output to the parent directory of the
lora
safetensor file
args
.
output_root
=
os
.
path
.
dirname
(
args
.
lora
_path
)
if
args
.
lora_name
is
None
:
base_name
=
os
.
path
.
basename
(
args
.
lora_path
)
lora_name
=
base_name
.
rsplit
(
"."
,
1
)[
0
]
...
...
nunchaku/lora/flux/diffusers_converter.py
View file @
e0fadc93
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
import
typing
as
tp
import
torch
...
...
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict
(
converted
,
{
"lora_down"
:
lora
[
0
],
"lora_up"
:
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
3
),
"lora_down"
:
pad
(
lora
[
0
],
divisor
=
16
,
dim
=
0
),
"lora_up"
:
pad
(
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
3
),
divisor
=
16
,
dim
=
1
),
},
prefix
=
converted_local_name
,
)
...
...
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict
(
converted
,
{
"lora_down"
:
lora
[
0
],
"lora_up"
:
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
6
),
"lora_down"
:
pad
(
lora
[
0
],
divisor
=
16
,
dim
=
0
),
"lora_up"
:
pad
(
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
6
),
divisor
=
16
,
dim
=
1
),
},
prefix
=
converted_local_name
,
)
...
...
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_A.weight"
)
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_B.weight"
)
for
component
in
[
"lora_A"
,
"lora_B"
]:
fc1_k
=
f
"
{
candidate_block_name
}
.proj_mlp.
{
component
}
.weight"
fc2_k
=
f
"
{
candidate_block_name
}
.proj_out.linears.1.
{
component
}
.weight"
fc1_v
=
extra_lora_dict
[
fc1_k
]
fc2_v
=
extra_lora_dict
[
fc2_k
]
dim
=
0
if
"lora_A"
in
fc1_k
else
1
fc1_rank
=
fc1_v
.
shape
[
dim
]
fc2_rank
=
fc2_v
.
shape
[
dim
]
if
fc1_rank
!=
fc2_rank
:
rank
=
max
(
fc1_rank
,
fc2_rank
)
if
fc1_rank
<
rank
:
extra_lora_dict
[
fc1_k
]
=
pad
(
fc1_v
,
divisor
=
rank
,
dim
=
dim
)
if
fc2_rank
<
rank
:
extra_lora_dict
[
fc2_k
]
=
pad
(
fc2_v
,
divisor
=
rank
,
dim
=
dim
)
return
convert_to_nunchaku_transformer_block_lowrank_dict
(
orig_state_dict
=
orig_state_dict
,
extra_lora_dict
=
extra_lora_dict
,
...
...
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
else
:
extra_lora_dict
=
filter_state_dict
(
lora
,
filter_prefix
=
"transformer."
)
for
k
in
extra_lora_dict
.
keys
():
fc1_k
=
k
if
"ff.net.0.proj"
in
k
:
fc2_k
=
k
.
replace
(
"ff.net.0.proj"
,
"ff.net.2"
)
elif
"ff_context.net.0.proj"
in
k
:
fc2_k
=
k
.
replace
(
"ff_context.net.0.proj"
,
"ff_context.net.2"
)
else
:
continue
assert
fc2_k
in
extra_lora_dict
fc1_v
=
extra_lora_dict
[
fc1_k
]
fc2_v
=
extra_lora_dict
[
fc2_k
]
dim
=
0
if
"lora_A"
in
fc1_k
else
1
fc1_rank
=
fc1_v
.
shape
[
dim
]
fc2_rank
=
fc2_v
.
shape
[
dim
]
if
fc1_rank
!=
fc2_rank
:
rank
=
max
(
fc1_rank
,
fc2_rank
)
if
fc1_rank
<
rank
:
extra_lora_dict
[
fc1_k
]
=
pad
(
fc1_v
,
divisor
=
rank
,
dim
=
dim
)
if
fc2_rank
<
rank
:
extra_lora_dict
[
fc2_k
]
=
pad
(
fc2_v
,
divisor
=
rank
,
dim
=
dim
)
block_names
:
set
[
str
]
=
set
()
for
param_name
in
orig_state_dict
.
keys
():
if
param_name
.
startswith
((
"transformer_blocks."
,
"single_transformer_blocks."
)):
...
...
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
),
prefix
=
block_name
,
)
return
converted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment