xuwx1 / LightX2V · Commits

Commit 089fa091, authored Jul 01, 2025 by GoatWu

    converter support lora

Parent: 11fcc3fb
Showing 4 changed files with 96 additions and 1 deletion:

- lightx2v/models/networks/wan/lora_adapter.py (+1 −1)
- tools/convert/converter.py (+65 −0)
- tools/convert/readme.md (+15 −0)
- tools/convert/readme_zh.md (+15 −0)
lightx2v/models/networks/wan/lora_adapter.py

@@ -96,7 +96,7 @@ class WanLoraWrapper:
                 name_diff = lora_diffs[name]
                 lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                param += lora_diff
+                param += lora_diff * alpha
                 applied_count += 1

         logger.info(f"Applied {applied_count} LoRA weight adjustments")
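For context, the one-line change above means a full-tensor LoRA diff is now scaled by the merge strength `alpha` before being added to the base parameter, i.e. W' = W + α·ΔW. A minimal, self-contained sketch of that formula (toy shapes and names, not taken from the repository):

```python
import torch

# Toy stand-ins for a base parameter and a same-shape LoRA "diff" tensor.
param = torch.randn(8, 16)
lora_diff = torch.randn(8, 16)
alpha = 0.75  # merge strength

# Previously the diff was applied at full strength (param += lora_diff);
# after this commit it is scaled by alpha:
param += lora_diff * alpha
```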
tools/convert/converter.py

@@ -398,6 +398,53 @@ def quantize_model(
     return weights


+def load_loras(lora_path, weight_dict, alpha):
+    logger.info(f"Loading LoRA from: {lora_path}")
+    with safe_open(lora_path, framework="pt") as f:
+        lora_weights = {k: f.get_tensor(k) for k in f.keys()}
+
+    lora_pairs = {}
+    lora_diffs = {}
+    prefix = "diffusion_model."
+
+    def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
+        if key.endswith(suffix_a):
+            base_name = key[len(prefix):].replace(suffix_a, target_suffix)
+            pair_key = key.replace(suffix_a, suffix_b)
+            if pair_key in lora_weights:
+                lora_pairs[base_name] = (key, pair_key)
+
+    def try_lora_diff(key, suffix, target_suffix):
+        if key.endswith(suffix):
+            base_name = key[len(prefix):].replace(suffix, target_suffix)
+            lora_diffs[base_name] = key
+
+    for key in lora_weights.keys():
+        if not key.startswith(prefix):
+            continue
+        try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
+        try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
+        try_lora_diff(key, "diff", "weight")
+        try_lora_diff(key, "diff_b", "bias")
+
+    applied_count = 0
+    for name, param in weight_dict.items():
+        if name in lora_pairs:
+            name_lora_A, name_lora_B = lora_pairs[name]
+            lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
+            lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
+            param += torch.matmul(lora_B, lora_A) * alpha
+            applied_count += 1
+        elif name in lora_diffs:
+            name_diff = lora_diffs[name]
+            lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
+            param += lora_diff * alpha
+            applied_count += 1
+
+    logger.info(f"Applied {applied_count} LoRA weight adjustments")
+
+
 def convert_weights(args):
     if os.path.isdir(args.source):
         src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
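To summarize the new `load_loras` helper: it reads a LoRA safetensors file, matches keys under the `diffusion_model.` prefix against three layouts (`lora_A`/`lora_B` pairs, `lora_down`/`lora_up` pairs, and full `diff`/`diff_b` tensors), and folds them into the base weights. A minimal sketch of the pair-merge rule with toy shapes (the variable names below are invented for illustration):

```python
import torch

# Toy base weight (out_features x in_features) and a rank-4 low-rank pair.
weight = torch.zeros(8, 16)
lora_down = torch.randn(4, 16)  # "lora_A.weight" / "lora_down.weight": rank x in
lora_up = torch.randn(8, 4)     # "lora_B.weight" / "lora_up.weight":   out x rank
alpha = 1.0

# Same rule as the pair branch above: W += alpha * (up @ down)
weight += torch.matmul(lora_up, lora_down) * alpha

# Full-tensor entries ("diff" / "diff_b") are simply added, scaled by alpha.
```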
@@ -423,6 +470,16 @@ def convert_weights(args):
                 raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
             merged_weights.update(weights)

+    if args.lora_path is not None:
+        # Handle alpha list - if single alpha, replicate for all LoRAs
+        if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
+            args.lora_alpha = args.lora_alpha * len(args.lora_path)
+        elif len(args.lora_alpha) != len(args.lora_path):
+            raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
+        for path, alpha in zip(args.lora_path, args.lora_alpha):
+            load_loras(path, merged_weights, alpha)
+
     if args.direction is not None:
         rules = get_key_mapping_rules(args.direction, args.model_type)
         converted_weights = {}
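The hunk above lets a single `--lora_alpha` value be broadcast across several `--lora_path` entries, and otherwise requires the two lists to have the same length. A standalone sketch of that validation (the helper name is made up for this example):

```python
def broadcast_alphas(lora_paths, lora_alphas):
    # Replicate a single alpha for every LoRA, or insist on matching lengths.
    if len(lora_alphas) == 1 and len(lora_paths) > 1:
        return lora_alphas * len(lora_paths)
    if len(lora_alphas) != len(lora_paths):
        raise ValueError(
            f"Number of lora_alpha ({len(lora_alphas)}) must match "
            f"number of lora_path ({len(lora_paths)}) or be 1"
        )
    return lora_alphas


print(broadcast_alphas(["a.safetensors", "b.safetensors"], [1.0]))  # [1.0, 1.0]
```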
@@ -584,6 +641,14 @@ def main():
         choices=["torch.int8", "torch.float8_e4m3fn"],
         help="Data type for quantization",
     )
+    parser.add_argument("--lora_path", type=str, nargs="*", help="Path(s) to LoRA file(s). Can specify multiple paths separated by spaces.")
+    parser.add_argument(
+        "--lora_alpha",
+        type=float,
+        nargs="*",
+        default=[1.0],
+        help="Alpha for LoRA weight scaling",
+    )
     args = parser.parse_args()

     if args.quantized:
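Both new flags use `nargs="*"`, so multiple values are passed space-separated and parsed into Python lists. A standalone sketch of the parsing behavior (just these two options in isolation, not the project's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lora_path", type=str, nargs="*",
                    help="Path(s) to LoRA file(s).")
parser.add_argument("--lora_alpha", type=float, nargs="*", default=[1.0],
                    help="Alpha for LoRA weight scaling")

# Example invocation with two LoRAs and per-LoRA alphas.
args = parser.parse_args(
    ["--lora_path", "lora1.safetensors", "lora2.safetensors",
     "--lora_alpha", "0.8", "1.0"]
)
print(args.lora_path)   # ['lora1.safetensors', 'lora2.safetensors']
print(args.lora_alpha)  # [0.8, 1.0]
```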
tools/convert/readme.md

@@ -50,6 +50,21 @@ python converter.py \
     --model_type wan_dit
 ```

+### Wan DiT + LoRA
+
+```bash
+python converter.py \
+    --quantized \
+    --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
+    --output /Path/To/output \
+    --output_ext .safetensors \
+    --output_name wan_int8 \
+    --dtype torch.int8 \
+    --model_type wan_dit \
+    --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
+    --lora_alpha 1.0 1.0
+```
+
 ### Hunyuan DIT

 ```bash
tools/convert/readme_zh.md

@@ -50,6 +50,21 @@ python converter.py \
     --model_type wan_dit
 ```

+### Wan DiT + LoRA
+
+```bash
+python converter.py \
+    --quantized \
+    --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
+    --output /Path/To/output \
+    --output_ext .safetensors \
+    --output_name wan_int8 \
+    --dtype torch.int8 \
+    --model_type wan_dit \
+    --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
+    --lora_alpha 1.0 1.0
+```
+
 ### Hunyuan DIT

 ```bash