Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
e0fadc93
Commit
e0fadc93
authored
Feb 23, 2025
by
muyangli
Browse files
[minor] fix some corner cases in lora conversion
parent
d7896cb4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
10 deletions
+78
-10
nunchaku/__version__.py
nunchaku/__version__.py
+1
-1
nunchaku/lora/flux/comfyui_converter.py
nunchaku/lora/flux/comfyui_converter.py
+32
-2
nunchaku/lora/flux/convert.py
nunchaku/lora/flux/convert.py
+2
-2
nunchaku/lora/flux/diffusers_converter.py
nunchaku/lora/flux/diffusers_converter.py
+43
-5
No files found.
nunchaku/__version__.py
View file @
e0fadc93
__version__
=
"0.1.
0
"
__version__
=
"0.1.
1
"
nunchaku/lora/flux/comfyui_converter.py
View file @
e0fadc93
# convert the comfyui lora to diffusers format
# convert the comfyui lora to diffusers format
import
argparse
import
os
import
os
import
torch
import
torch
...
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
...
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
def
comfyui2diffusers
(
def
comfyui2diffusers
(
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
input_lora
:
str
|
dict
[
str
,
torch
.
Tensor
],
output_path
:
str
|
None
=
None
,
min_rank
:
int
|
None
=
None
)
->
dict
[
str
,
torch
.
Tensor
]:
)
->
dict
[
str
,
torch
.
Tensor
]:
if
isinstance
(
input_lora
,
str
):
if
isinstance
(
input_lora
,
str
):
tensors
=
load_state_dict_in_safetensors
(
input_lora
,
device
=
"cpu"
)
tensors
=
load_state_dict_in_safetensors
(
input_lora
,
device
=
"cpu"
)
...
@@ -16,9 +17,10 @@ def comfyui2diffusers(
...
@@ -16,9 +17,10 @@ def comfyui2diffusers(
tensors
=
input_lora
tensors
=
input_lora
new_tensors
=
{}
new_tensors
=
{}
max_alpha
=
0
for
k
,
v
in
tensors
.
items
():
for
k
,
v
in
tensors
.
items
():
if
"alpha"
in
k
:
if
"alpha"
in
k
:
max_alpha
=
max
(
max_alpha
,
v
.
max
().
item
())
continue
continue
new_k
=
k
.
replace
(
"lora_down"
,
"lora_A"
).
replace
(
"lora_up"
,
"lora_B"
)
new_k
=
k
.
replace
(
"lora_down"
,
"lora_A"
).
replace
(
"lora_up"
,
"lora_B"
)
if
"lora_unet_double_blocks_"
in
k
:
if
"lora_unet_double_blocks_"
in
k
:
...
@@ -72,8 +74,36 @@ def comfyui2diffusers(
...
@@ -72,8 +74,36 @@ def comfyui2diffusers(
new_k
=
new_k
.
replace
(
"_modulation_lin"
,
".norm.linear"
)
new_k
=
new_k
.
replace
(
"_modulation_lin"
,
".norm.linear"
)
new_tensors
[
new_k
]
=
v
new_tensors
[
new_k
]
=
v
if
min_rank
is
not
None
:
for
k
in
new_tensors
.
keys
():
v
=
new_tensors
[
k
]
if
"lora_A"
in
k
:
rank
=
v
.
shape
[
0
]
if
rank
<
min_rank
:
new_v
=
torch
.
zeros
(
min_rank
,
v
.
shape
[
1
],
dtype
=
v
.
dtype
,
device
=
v
.
device
)
new_v
[:
rank
]
=
v
new_tensors
[
k
]
=
new_v
else
:
assert
"lora_B"
in
k
rank
=
v
.
shape
[
1
]
if
rank
<
min_rank
:
new_v
=
torch
.
zeros
(
v
.
shape
[
0
],
min_rank
,
dtype
=
v
.
dtype
,
device
=
v
.
device
)
new_v
[:,
:
rank
]
=
v
new_tensors
[
k
]
=
new_v
if
output_path
is
not
None
:
if
output_path
is
not
None
:
output_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
output_path
))
output_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
output_path
))
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
save_file
(
new_tensors
,
output_path
)
save_file
(
new_tensors
,
output_path
)
return
new_tensors
return
new_tensors
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-i"
,
"--input-path"
,
type
=
str
,
required
=
True
,
help
=
"path to the comfyui lora safetensor file"
)
parser
.
add_argument
(
"-o"
,
"--output-path"
,
type
=
str
,
required
=
True
,
help
=
"path to the output diffusers safetensor file"
)
parser
.
add_argument
(
"--min-rank"
,
type
=
int
,
default
=
None
,
help
=
"minimum rank for the LoRA weights"
)
args
=
parser
.
parse_args
()
comfyui2diffusers
(
args
.
input_path
,
args
.
output_path
,
min_rank
=
args
.
min_rank
)
nunchaku/lora/flux/convert.py
View file @
e0fadc93
...
@@ -37,8 +37,8 @@ if __name__ == "__main__":
...
@@ -37,8 +37,8 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
not
args
.
output_root
:
if
not
args
.
output_root
:
# output to the parent directory of the
quantized model
safetensor file
# output to the parent directory of the
lora
safetensor file
args
.
output_root
=
os
.
path
.
dirname
(
args
.
quant
_path
)
args
.
output_root
=
os
.
path
.
dirname
(
args
.
lora
_path
)
if
args
.
lora_name
is
None
:
if
args
.
lora_name
is
None
:
base_name
=
os
.
path
.
basename
(
args
.
lora_path
)
base_name
=
os
.
path
.
basename
(
args
.
lora_path
)
lora_name
=
base_name
.
rsplit
(
"."
,
1
)[
0
]
lora_name
=
base_name
.
rsplit
(
"."
,
1
)[
0
]
...
...
nunchaku/lora/flux/diffusers_converter.py
View file @
e0fadc93
# convert the diffusers lora to nunchaku format
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
"""Convert LoRA weights to Nunchaku format."""
import
typing
as
tp
import
typing
as
tp
import
torch
import
torch
...
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
...
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict
(
update_state_dict
(
converted
,
converted
,
{
{
"lora_down"
:
lora
[
0
],
"lora_down"
:
pad
(
lora
[
0
],
divisor
=
16
,
dim
=
0
),
"lora_up"
:
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
3
),
"lora_up"
:
pad
(
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
3
),
divisor
=
16
,
dim
=
1
),
},
},
prefix
=
converted_local_name
,
prefix
=
converted_local_name
,
)
)
...
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
...
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict
(
update_state_dict
(
converted
,
converted
,
{
{
"lora_down"
:
lora
[
0
],
"lora_down"
:
pad
(
lora
[
0
],
divisor
=
16
,
dim
=
0
),
"lora_up"
:
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
6
),
"lora_up"
:
pad
(
reorder_adanorm_lora_up
(
lora
[
1
],
splits
=
6
),
divisor
=
16
,
dim
=
1
),
},
},
prefix
=
converted_local_name
,
prefix
=
converted_local_name
,
)
)
...
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
...
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_A.weight"
)
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_A.weight"
)
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_B.weight"
)
extra_lora_dict
.
pop
(
f
"
{
candidate_block_name
}
.proj_out.lora_B.weight"
)
for
component
in
[
"lora_A"
,
"lora_B"
]:
fc1_k
=
f
"
{
candidate_block_name
}
.proj_mlp.
{
component
}
.weight"
fc2_k
=
f
"
{
candidate_block_name
}
.proj_out.linears.1.
{
component
}
.weight"
fc1_v
=
extra_lora_dict
[
fc1_k
]
fc2_v
=
extra_lora_dict
[
fc2_k
]
dim
=
0
if
"lora_A"
in
fc1_k
else
1
fc1_rank
=
fc1_v
.
shape
[
dim
]
fc2_rank
=
fc2_v
.
shape
[
dim
]
if
fc1_rank
!=
fc2_rank
:
rank
=
max
(
fc1_rank
,
fc2_rank
)
if
fc1_rank
<
rank
:
extra_lora_dict
[
fc1_k
]
=
pad
(
fc1_v
,
divisor
=
rank
,
dim
=
dim
)
if
fc2_rank
<
rank
:
extra_lora_dict
[
fc2_k
]
=
pad
(
fc2_v
,
divisor
=
rank
,
dim
=
dim
)
return
convert_to_nunchaku_transformer_block_lowrank_dict
(
return
convert_to_nunchaku_transformer_block_lowrank_dict
(
orig_state_dict
=
orig_state_dict
,
orig_state_dict
=
orig_state_dict
,
extra_lora_dict
=
extra_lora_dict
,
extra_lora_dict
=
extra_lora_dict
,
...
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
...
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
else
:
else
:
extra_lora_dict
=
filter_state_dict
(
lora
,
filter_prefix
=
"transformer."
)
extra_lora_dict
=
filter_state_dict
(
lora
,
filter_prefix
=
"transformer."
)
for
k
in
extra_lora_dict
.
keys
():
fc1_k
=
k
if
"ff.net.0.proj"
in
k
:
fc2_k
=
k
.
replace
(
"ff.net.0.proj"
,
"ff.net.2"
)
elif
"ff_context.net.0.proj"
in
k
:
fc2_k
=
k
.
replace
(
"ff_context.net.0.proj"
,
"ff_context.net.2"
)
else
:
continue
assert
fc2_k
in
extra_lora_dict
fc1_v
=
extra_lora_dict
[
fc1_k
]
fc2_v
=
extra_lora_dict
[
fc2_k
]
dim
=
0
if
"lora_A"
in
fc1_k
else
1
fc1_rank
=
fc1_v
.
shape
[
dim
]
fc2_rank
=
fc2_v
.
shape
[
dim
]
if
fc1_rank
!=
fc2_rank
:
rank
=
max
(
fc1_rank
,
fc2_rank
)
if
fc1_rank
<
rank
:
extra_lora_dict
[
fc1_k
]
=
pad
(
fc1_v
,
divisor
=
rank
,
dim
=
dim
)
if
fc2_rank
<
rank
:
extra_lora_dict
[
fc2_k
]
=
pad
(
fc2_v
,
divisor
=
rank
,
dim
=
dim
)
block_names
:
set
[
str
]
=
set
()
block_names
:
set
[
str
]
=
set
()
for
param_name
in
orig_state_dict
.
keys
():
for
param_name
in
orig_state_dict
.
keys
():
if
param_name
.
startswith
((
"transformer_blocks."
,
"single_transformer_blocks."
)):
if
param_name
.
startswith
((
"transformer_blocks."
,
"single_transformer_blocks."
)):
...
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
...
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
),
),
prefix
=
block_name
,
prefix
=
block_name
,
)
)
return
converted
return
converted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment