xuwx1 / LightX2V · Commits

Commit a8ad2d7d
Authored Apr 11, 2025 by lijiaqi2; committed by gaopeng, Apr 11, 2025
Parent: 6c18f54c

feat: support LoRA
Showing 3 changed files, with 168 additions and 18 deletions:

  lightx2v/__main__.py                                   +38  -12
  lightx2v/text2v/models/networks/wan/lora_adapter.py   +116   -0
  lightx2v/text2v/models/networks/wan/model.py           +14   -6
lightx2v/__main__.py

 import argparse
+from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 import os

@@ -21,6 +22,8 @@ from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanS
 from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
 from lightx2v.text2v.models.networks.wan.model import WanModel
+from lightx2v.text2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE

@@ -29,6 +32,14 @@ from lightx2v.common.ops import *
 from lightx2v.image2v.models.wan.model import CLIPModel


+@contextmanager
+def time_duration(label: str = ""):
+    start_time = time.time()
+    yield
+    end_time = time.time()
+    print(f"==> {label} start: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
+
+
 def load_models(args, model_config):
     if model_config["parallel_attn_type"]:
         cur_rank = dist.get_rank()  # get the rank of the current process

@@ -59,15 +70,27 @@ def load_models(args, model_config):
             shard_fn=None,
         )
         text_encoders = [text_encoder]
-        model = WanModel(args.model_path, model_config, init_device)
-        vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
+        with time_duration("Load Wan Model"):
+            model = WanModel(args.model_path, model_config, init_device)
+
+        if args.lora_path:
+            lora_wrapper = WanLoraWrapper(model)
+            with time_duration("Load LoRA Model"):
+                lora_name = lora_wrapper.load_lora(args.lora_path)
+                lora_wrapper.apply_lora(lora_name, args.strength_model)
+                print(f"Loaded LoRA: {lora_name}")
+
+        with time_duration("Load WAN VAE Model"):
+            vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
         if args.task == "i2v":
-            image_encoder = CLIPModel(
-                dtype=torch.float16,
-                device=init_device,
-                checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
-                tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
-            )
+            with time_duration("Load Image Encoder"):
+                image_encoder = CLIPModel(
+                    dtype=torch.float16,
+                    device=init_device,
+                    checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+                    tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
+                )
         else:
             raise NotImplementedError(f"Unsupported model class: {args.model_cls}")

@@ -312,6 +335,10 @@ if __name__ == "__main__":
     parser.add_argument("--patch_size", default=(1, 2, 2))
     parser.add_argument("--teacache_thresh", type=float, default=0.26)
     parser.add_argument("--use_ret_steps", action="store_true", default=False)
+    parser.add_argument("--use_bfloat16", action="store_true", default=True)
+    parser.add_argument("--lora_path", type=str, default=None)
+    parser.add_argument("--strength_model", type=float, default=1.0)
     args = parser.parse_args()

     start_time = time.time()

@@ -338,6 +365,7 @@ if __name__ == "__main__":
         "feature_caching": args.feature_caching,
         "parallel_attn_type": args.parallel_attn_type,
         "parallel_vae": args.parallel_vae,
+        "use_bfloat16": args.use_bfloat16,
     }
     if args.config_path is not None:

@@ -347,10 +375,8 @@ if __name__ == "__main__":
     print(f"model_config: {model_config}")

-    model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
-    load_models_time = time.time()
-    print(f"Load models cost: {load_models_time - start_time}")
+    with time_duration("Load models"):
+        model, text_encoders, vae_model, image_encoder = load_models(args, model_config)

     if args.task in ["i2v"]:
         image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
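With the new flags, a LoRA checkpoint can be merged into the Wan weights at startup. A hypothetical invocation is sketched below: the entry point, model directory, and LoRA file name are placeholders, and other required arguments (model class, prompt/output options, etc.) are omitted.

    python lightx2v/__main__.py \
        --task t2v \
        --model_path /path/to/Wan2.1 \
        --lora_path /path/to/my_lora.safetensors \
        --strength_model 1.0

--strength_model is passed straight through as the alpha of WanLoraWrapper.apply_lora, so 1.0 merges the LoRA delta at full strength and smaller values scale it down.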
lightx2v/text2v/models/networks/wan/lora_adapter.py (new file, mode 100644)

import os
import torch
from safetensors import safe_open
from loguru import logger
import gc


class WanLoraWrapper:
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_dict = {}
        self.override_dict = {}

    def load_lora(self, lora_path, lora_name=None):
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_dict:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        lora_weights = self._load_lora_file(lora_path)
        self.lora_dict[lora_name] = lora_weights
        return lora_name

    def _load_lora_file(self, file_path):
        use_bfloat16 = True  # Default value
        if self.model.config and hasattr(self.model.config, "get"):
            use_bfloat16 = self.model.config.get("use_bfloat16", True)

        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
            else:
                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0):
        if lora_name not in self.lora_dict:
            logger.info(f"LoRA {lora_name} not found. Please load it first.")

        if hasattr(self.model, "current_lora") and self.model.current_lora:
            self.remove_lora()

        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False

        weight_dict = self.model.original_weight_dict
        lora_weights = self.lora_dict[lora_name]
        self._apply_lora_weights(weight_dict, lora_weights, alpha)

        # reload the (now merged) weights into the model
        self.model.pre_weight.load_weights(weight_dict)
        self.model.post_weight.load_weights(weight_dict)
        self.model.transformer_weights.load_weights(weight_dict)

        self.model.current_lora = lora_name
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        lora_pairs = {}
        prefix = "diffusion_model."
        for key in lora_weights.keys():
            if key.endswith("lora_A.weight") and key.startswith(prefix):
                base_name = key[len(prefix):].replace("lora_A.weight", "weight")
                b_key = key.replace("lora_A.weight", "lora_B.weight")
                if b_key in lora_weights:
                    lora_pairs[base_name] = (key, b_key)

        applied_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                param += torch.matmul(lora_B, lora_A) * alpha
                applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.info("Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file.")

    def remove_lora(self):
        if not self.model.current_lora:
            logger.info("No LoRA currently applied")
            return

        logger.info(f"Removing LoRA {self.model.current_lora}...")

        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1
        logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")

        self.model.pre_weight.load_weights(self.model.original_weight_dict)
        self.model.post_weight.load_weights(self.model.original_weight_dict)
        self.model.transformer_weights.load_weights(self.model.original_weight_dict)

        if self.model.current_lora and self.model.current_lora in self.lora_dict:
            del self.lora_dict[self.model.current_lora]

        self.override_dict = {}
        torch.cuda.empty_cache()
        gc.collect()

    def list_loaded_loras(self):
        return list(self.lora_dict.keys())

    def get_current_lora(self):
        return self.model.current_lora
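The heart of _apply_lora_weights is the standard LoRA merge: for each layer whose pair of adapters is found, the base weight W is updated in place to W + alpha * (B @ A), where lora_A is the rank-r down-projection and lora_B the up-projection. A minimal, self-contained sketch of that update on toy tensors (all shapes and names below are chosen for illustration only, not taken from the repository):

import torch

# Toy shapes: a 16x16 linear weight with LoRA rank 4.
d_out, d_in, rank, alpha = 16, 16, 4, 1.0

weight = torch.randn(d_out, d_in)         # base weight W, as held in original_weight_dict
lora_A = torch.randn(rank, d_in) * 0.01   # down-projection ("...lora_A.weight")
lora_B = torch.randn(d_out, rank) * 0.01  # up-projection   ("...lora_B.weight")

before = weight.clone()
# Same in-place update as `param += torch.matmul(lora_B, lora_A) * alpha`
# in WanLoraWrapper._apply_lora_weights.
weight += torch.matmul(lora_B, lora_A) * alpha

# The delta is exactly the rank-r product scaled by alpha.
assert torch.allclose(weight - before, lora_B @ lora_A * alpha)

Because the merge mutates model.original_weight_dict and the result is pushed back through load_weights, inference then runs on the merged tensors with no per-step LoRA overhead.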
lightx2v/text2v/models/networks/wan/model.py

@@ -30,6 +30,7 @@ class WanModel:
         self._init_infer_class()
         self._init_weights()
         self._init_infer()
+        self.current_lora = None

         if config["parallel_attn_type"]:
             if config["parallel_attn_type"] == "ulysses":

@@ -53,8 +54,12 @@ class WanModel:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

     def _load_safetensor_to_dict(self, file_path):
+        use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+            if use_bfloat16:
+                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+            else:
+                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):

@@ -69,16 +74,19 @@ class WanModel:
             weight_dict.update(file_weights)
         return weight_dict

-    def _init_weights(self):
-        weight_dict = self._load_ckpt()
+    def _init_weights(self, weight_dict=None):
+        if weight_dict is None:
+            self.original_weight_dict = self._load_ckpt()
+        else:
+            self.original_weight_dict = weight_dict

         # init weights
         self.pre_weight = self.pre_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

         # load weights
-        self.pre_weight.load_weights(weight_dict)
-        self.post_weight.load_weights(weight_dict)
-        self.transformer_weights.load_weights(weight_dict)
+        self.pre_weight.load_weights(self.original_weight_dict)
+        self.post_weight.load_weights(self.original_weight_dict)
+        self.transformer_weights.load_weights(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
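The other change here makes the bfloat16 cast configurable when reading safetensors shards. Below is a standalone sketch of that pattern: the free function mirrors the logic added to WanModel._load_safetensor_to_dict but is not the class method itself, and the tensor key and temporary file are made up for the example.

import os
import tempfile

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny safetensors file so the loader has something to read.
path = os.path.join(tempfile.mkdtemp(), "toy.safetensors")
save_file({"blocks.0.attn.qkv.weight": torch.randn(8, 8)}, path)


def load_safetensor_to_dict(file_path, use_bfloat16=True, device="cpu"):
    # Config-controlled cast: keep the checkpoint dtype or downcast to bfloat16.
    with safe_open(file_path, framework="pt") as f:
        if use_bfloat16:
            return {k: f.get_tensor(k).to(torch.bfloat16).to(device) for k in f.keys()}
        return {k: f.get_tensor(k).to(device) for k in f.keys()}


print(load_safetensor_to_dict(path, use_bfloat16=True)["blocks.0.attn.qkv.weight"].dtype)   # torch.bfloat16
print(load_safetensor_to_dict(path, use_bfloat16=False)["blocks.0.attn.qkv.weight"].dtype)  # torch.float32

Caching the loaded checkpoint as self.original_weight_dict (and letting _init_weights accept an external weight_dict) is what gives WanLoraWrapper a handle on the full weight dict to merge LoRA deltas into and to re-push through load_weights.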