LightX2V · Commit 83c12f2b

feat: refactor LoRA handling and add run script (#16)

Authored by PengGao on Apr 11, 2025; committed via GitHub on Apr 11, 2025.
Parent: 4eec372d
Showing 2 changed files with 70 additions and 20 deletions:

lightx2v/text2v/models/networks/wan/lora_adapter.py (+20, -20)
scripts/run_wan_i2v_with_lora.sh (+50, -0)
lightx2v/text2v/models/networks/wan/lora_adapter.py

@@ -8,20 +8,20 @@ import gc
 class WanLoraWrapper:
     def __init__(self, wan_model):
         self.model = wan_model
-        self.lora_dict = {}
+        self.lora_metadata = {}
         self.override_dict = {}  # On CPU

     def load_lora(self, lora_path, lora_name=None):
         if lora_name is None:
             lora_name = os.path.basename(lora_path).split(".")[0]

-        if lora_name in self.lora_dict:
+        if lora_name in self.lora_metadata:
             logger.info(f"LoRA {lora_name} already loaded, skipping...")
             return lora_name

-        lora_weights = self._load_lora_file(lora_path)
-        self.lora_dict[lora_name] = lora_weights
+        self.lora_metadata[lora_name] = {"path": lora_path}
+        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")

         return lora_name

     def _load_lora_file(self, file_path):
@@ -36,7 +36,7 @@ class WanLoraWrapper:
         return tensor_dict

     def apply_lora(self, lora_name, alpha=1.0):
-        if lora_name not in self.lora_dict:
+        if lora_name not in self.lora_metadata:
             logger.info(f"LoRA {lora_name} not found. Please load it first.")

         if hasattr(self.model, "current_lora") and self.model.current_lora:
@@ -46,19 +46,16 @@ class WanLoraWrapper:
             logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
             return False

+        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
         weight_dict = self.model.original_weight_dict
-        lora_weights = self.lora_dict[lora_name]
         self._apply_lora_weights(weight_dict, lora_weights, alpha)

-        # Reload the weights
-        self.model.pre_weight.load_weights(weight_dict)
-        self.model.post_weight.load_weights(weight_dict)
-        self.model.transformer_weights.load_weights(weight_dict)
+        self.model._init_weights(weight_dict)

         self.model.current_lora = lora_name
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         return True

     @torch.no_grad()
     def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
         lora_pairs = {}
         prefix = "diffusion_model."
@@ -73,6 +70,9 @@ class WanLoraWrapper:
         applied_count = 0
         for name, param in weight_dict.items():
             if name in lora_pairs:
+                if name not in self.override_dict:
+                    self.override_dict[name] = param.clone().cpu()
                 name_lora_A, name_lora_B = lora_pairs[name]
                 lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                 lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
@@ -85,6 +85,7 @@ class WanLoraWrapper:
                 "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
             )

     @torch.no_grad()
     def remove_lora(self):
         if not self.model.current_lora:
             logger.info("No LoRA currently applied")
@@ -98,19 +99,18 @@ class WanLoraWrapper:
         logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")

-        self.model.pre_weight.load_weights(self.model.original_weight_dict)
-        self.model.post_weight.load_weights(self.model.original_weight_dict)
-        self.model.transformer_weights.load_weights(self.model.original_weight_dict)
-        if self.model.current_lora and self.model.current_lora in self.lora_dict:
-            del self.lora_dict[self.model.current_lora]
-        self.override_dict = {}
+        self.model._init_weights(self.model.original_weight_dict)
         torch.cuda.empty_cache()
         gc.collect()
+        if self.model.current_lora and self.model.current_lora in self.lora_metadata:
+            del self.lora_metadata[self.model.current_lora]
+        self.override_dict = {}

         self.model.current_lora = None

     def list_loaded_loras(self):
-        return list(self.lora_dict.keys())
+        return list(self.lora_metadata.keys())

     def get_current_lora(self):
         return self.model.current_lora
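Taken together, the refactor changes when LoRA weights touch memory: load_lora now records only {"path": lora_path} in self.lora_metadata, apply_lora reads the safetensors file on demand, and _apply_lora_weights backs each patched weight up to CPU in self.override_dict before folding in the update, so remove_lora can restore the originals. The hunk that computes the actual update is elided above, so the fold in the sketch below is the standard LoRA formula and an assumption; the layer name, shapes, and random tensors are likewise made up for illustration. A minimal sketch of the backup-merge-restore pattern:

import torch

# Stand-ins for the real structures; layer name and shapes are hypothetical.
alpha = 1.0
key = "blocks.0.attn.qkv.weight"
weight_dict = {key: torch.randn(8, 8)}           # plays the role of model.original_weight_dict
lora_weights = {                                  # plays the role of the loaded .safetensors tensors
    "diffusion_model.blocks.0.attn.qkv.lora_A.weight": torch.randn(4, 8),
    "diffusion_model.blocks.0.attn.qkv.lora_B.weight": torch.randn(8, 4),
}
override_dict = {}                                # CPU backups of every patched weight

with torch.no_grad():
    param = weight_dict[key]
    if key not in override_dict:
        override_dict[key] = param.clone().cpu()  # pristine copy, used later by remove_lora
    lora_A = lora_weights["diffusion_model.blocks.0.attn.qkv.lora_A.weight"].to(param.device, param.dtype)
    lora_B = lora_weights["diffusion_model.blocks.0.attn.qkv.lora_B.weight"].to(param.device, param.dtype)
    param += alpha * (lora_B @ lora_A)            # assumed fold: W' = W + alpha * (B @ A)

# remove_lora() restores from the CPU backup and drops it:
weight_dict[key] = override_dict.pop(key)

The caller-facing sequence is unchanged: wrapper = WanLoraWrapper(model), name = wrapper.load_lora(lora_path), wrapper.apply_lora(name, alpha=1.0); only the point at which the file is actually read has moved from load time to apply time.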
scripts/run_wan_i2v_with_lora.sh (new file, mode 100755)
#!/bin/bash

# set paths first
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
lightx2v_path="$(dirname "$script_dir")"

model_path=/mnt/aigc/shared_data/cache/huggingface/hub/Wan2.1-I2V-14B-480P
config_path=$model_path/config.json
lora_path=/mnt/aigc/shared_data/wan_quant/lora/toy_zoe_epoch_324.safetensors

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Change it in this script or set the environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${config_path}" ]; then
    echo "Error: config_path is not set. Please set this variable first."
    exit 1
fi

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

python -m lightx2v \
    --model_cls wan2.1 \
    --task i2v \
    --model_path $model_path \
    --prompt "画面中的物体轻轻向上跃起,变成了外貌相似的毛绒玩具。毛绒玩具有着一双眼睛,它的颜色和之前的一样。然后,它开始跳跃起来。背景保持一致,气氛显得格外俏皮。" \
    --infer_steps 40 \
    --target_video_length 81 \
    --target_width 832 \
    --target_height 480 \
    --attention_type flash_attn3 \
    --seed 42 \
    --sample_neg_promp "画面过曝,模糊,文字,字幕" \
    --config_path $config_path \
    --save_video_path ./output_lightx2v_wan_i2v.mp4 \
    --sample_guide_scale 5 \
    --sample_shift 5 \
    --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
    --lora_path ${lora_path} \
    --feature_caching Tea \
    --mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
    # --mm_config '{"mm_type": "Default", "weight_auto_quant": true}' \
    # --use_ret_steps \
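For reference, the positive prompt translates to: "The object in the frame gently leaps upward and turns into a plush toy with a similar appearance. The plush toy has a pair of eyes whose color matches the original. Then it starts bouncing. The background stays the same, and the mood feels especially playful." The negative prompt translates to: "overexposed frame, blurry, text, subtitles." Because lightx2v_path is derived from the script's own location, the script can be invoked from any working directory, e.g. bash scripts/run_wan_i2v_with_lora.sh from the repository root, provided the hard-coded model_path and lora_path under /mnt/aigc/ are adjusted for the local machine.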