Project: xuwx1 / LightX2V

Commit 682037cd (unverified), authored Sep 27, 2025 by gushiqiao, committed via GitHub on Sep 27, 2025

[Feat] Add wan2.2 animate model (#339)

parent e251e4dc
Showing 8 changed files with 464 additions and 29 deletions (+464 −29)
lightx2v/models/runners/default_runner.py          +11  −2
lightx2v/models/runners/wan/wan_animate_runner.py  +380 −0
lightx2v/models/runners/wan/wan_audio_runner.py    +1   −1
lightx2v/models/runners/wan/wan_runner.py           +3   −3
lightx2v/models/schedulers/wan/scheduler.py         +3   −1
lightx2v/utils/registry_factory.py                  +0   −1
lightx2v/utils/utils.py                             +47  −21
scripts/wan22/run_wan22_animate.sh                  +19  −0
lightx2v/models/runners/default_runner.py
...
...
@@ -47,6 +47,8 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_t2v
         elif self.config["task"] == "vace":
             self.run_input_encoder = self._run_input_encoder_local_vace
+        elif self.config["task"] == "animate":
+            self.run_input_encoder = self._run_input_encoder_local_animate
         if self.config.get("compile", False):
             logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
             self.model.compile(self.config.get("compile_shapes", []))
...
...
@@ -216,6 +218,14 @@ class DefaultRunner(BaseRunner):
         gc.collect()
         return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

+    @ProfilingContext4DebugL2("Run Text Encoder")
+    def _run_input_encoder_local_animate(self):
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt, None)
+        torch.cuda.empty_cache()
+        gc.collect()
+        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)
+
     def init_run(self):
         self.set_target_shape()
         self.get_video_segment_num()
...
...
@@ -241,7 +251,7 @@ class DefaultRunner(BaseRunner):
             # 3. vae decoder
             self.gen_video = self.run_vae_decoder(latents)
             # 4. default do nothing
-            self.end_run_segment()
+            self.end_run_segment(segment_idx)
         self.end_run()

     @ProfilingContext4DebugL1("Run VAE Decoder")
...
...
@@ -304,7 +314,6 @@ class DefaultRunner(BaseRunner):
             self.config["prompt_enhanced"] = self.post_prompt_enhancer()
         self.inputs = self.run_input_encoder()
         self.run_main()
         gen_video = self.process_images_after_vae_decoder(save_video=save_video)
...
...
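Editor's note: the hunk above threads segment_idx into end_run_segment, but the enclosing loop is collapsed in this view. As a rough sketch only, inferred from the runner hooks visible in this commit (run_main_segment below is a placeholder name, not a real method), the per-segment flow that the animate runner plugs into looks roughly like this:

    # Illustrative sketch, not code from the repository.
    def run_segments(runner):
        runner.init_run()                                 # set_target_shape() + get_video_segment_num()
        for segment_idx in range(runner.video_segment_num):
            runner.init_run_segment(segment_idx)          # per-segment conditioning (overridden by WanAnimateRunner)
            latents = runner.run_main_segment()           # placeholder for the denoising loop
            runner.gen_video = runner.run_vae_decoder(latents)
            runner.end_run_segment(segment_idx)           # now receives the segment index (this commit)
        runner.end_run()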
lightx2v/models/runners/wan/wan_animate_runner.py (new file, mode 100755)
import gc
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader

from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator
from lightx2v.models.networks.wan.animate_model import WanAnimateModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
@RUNNER_REGISTER("wan2.2_animate")
class WanAnimateRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        assert self.config.task == "animate"
    def inputs_padding(self, array, target_len):
        idx = 0
        flip = False
        target_array = []
        while len(target_array) < target_len:
            target_array.append(deepcopy(array[idx]))
            if flip:
                idx -= 1
            else:
                idx += 1
            if idx == 0 or idx == len(array) - 1:
                flip = not flip
        return target_array[:target_len]
    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        real_clip_len = clip_len - overlap
        last_clip_num = (real_len - overlap) % real_clip_len
        if last_clip_num == 0:
            extra = 0
        else:
            extra = real_clip_len - last_clip_num
        target_len = real_len + extra
        return target_len
    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        if mask_pixel_values is None:
            msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, dtype=GET_DTYPE(), device=device)
        else:
            msk = mask_pixel_values.clone()
        msk[:, :mask_len] = 1
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        return msk
    def padding_resize(
        self,
        img_ori,
        height=512,
        width=512,
        padding_color=(0, 0, 0),
        interpolation=cv2.INTER_LINEAR,
    ):
        ori_height = img_ori.shape[0]
        ori_width = img_ori.shape[1]
        channel = img_ori.shape[2]

        img_pad = np.zeros((height, width, channel))
        if channel == 1:
            img_pad[:, :, 0] = padding_color[0]
        else:
            img_pad[:, :, 0] = padding_color[0]
            img_pad[:, :, 1] = padding_color[1]
            img_pad[:, :, 2] = padding_color[2]

        if (ori_height / ori_width) > (height / width):
            new_width = int(height / ori_height * ori_width)
            img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
            padding = int((width - new_width) / 2)
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
            img_pad[:, padding : padding + new_width, :] = img
        else:
            new_height = int(width / ori_width * ori_height)
            img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
            padding = int((height - new_height) / 2)
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
            img_pad[padding : padding + new_height, :, :] = img

        img_pad = np.uint8(img_pad)
        return img_pad
    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        pose_video_reader = VideoReader(src_pose_path)
        pose_len = len(pose_video_reader)
        pose_idxs = list(range(pose_len))
        cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()

        face_video_reader = VideoReader(src_face_path)
        face_len = len(face_video_reader)
        face_idxs = list(range(face_len))
        face_images = face_video_reader.get_batch(face_idxs).asnumpy()

        height, width = cond_images[0].shape[:2]
        refer_images = cv2.imread(src_ref_path)[..., ::-1]
        refer_images = self.padding_resize(refer_images, height=height, width=width)
        return cond_images, face_images, refer_images
    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        bg_video_reader = VideoReader(src_bg_path)
        bg_len = len(bg_video_reader)
        bg_idxs = list(range(bg_len))
        bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()

        mask_video_reader = VideoReader(src_mask_path)
        mask_len = len(mask_video_reader)
        mask_idxs = list(range(mask_len))
        mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
        mask_images = mask_images[:, :, :, 0] / 255
        return bg_images, mask_images
    @ProfilingContext4DebugL2("Run Image Encoders")
    def run_image_encoders(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
        face_pixel_values,
    ):
        clip_encoder_out = self.run_image_encoder(self.refer_pixel_values)
        vae_encoder_out, pose_latents = self.run_vae_encoder(
            conditioning_pixel_values,
            refer_t_pixel_values,
            bg_pixel_values,
            mask_pixel_values,
        )
        return {
            "image_encoder_output": {
                "clip_encoder_out": clip_encoder_out,
                "vae_encoder_out": vae_encoder_out,
                "pose_latents": pose_latents,
                "face_pixel_values": face_pixel_values,
            }
        }
    def run_vae_encoder(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
    ):
        H, W = self.refer_pixel_values.shape[-2], self.refer_pixel_values.shape[-1]
        pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  # c t h w
        ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  # c t h w

        mask_ref = self.get_i2v_mask(1, self.config.lat_h, self.config.lat_w, 1)
        y_ref = torch.concat([mask_ref, ref_latents])

        if self.mask_reft_len > 0:
            if self.config.replace_flag:
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len],
                            bg_pixel_values[:, self.mask_reft_len :],
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]
                msk_reft = self.get_i2v_mask(
                    self.config.lat_t,
                    self.config.lat_h,
                    self.config.lat_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            torch.nn.functional.interpolate(
                                refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len].cpu(),
                                size=(H, W),
                                mode="bicubic",
                            ),
                            torch.zeros(3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
        else:
            if self.config.replace_flag:
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]
                y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
                msk_reft = self.get_i2v_mask(
                    self.config.lat_t,
                    self.config.lat_h,
                    self.config.lat_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                y_reft = self.vae_encoder.encode(
                    torch.zeros(
                        1,
                        3,
                        self.config.target_video_length - self.mask_reft_len,
                        H,
                        W,
                        dtype=GET_DTYPE(),
                        device="cuda",
                    )
                )
                msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)

        y_reft = torch.concat([msk_reft, y_reft])
        y = torch.concat([y_ref, y_reft], dim=1)
        return y, pose_latents
    def prepare_input(self):
        src_pose_path = self.config.get("src_pose_path", None)
        src_face_path = self.config.get("src_face_path", None)
        src_ref_path = self.config.get("src_ref_images", None)
        self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
        self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1)  # chw
        self.real_frame_len = len(self.cond_images)
        target_len = self.get_valid_len(
            self.real_frame_len,
            self.config.target_video_length,
            overlap=self.config.get("refert_num", 1),
        )
        logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
        self.cond_images = self.inputs_padding(self.cond_images, target_len)
        self.face_images = self.inputs_padding(self.face_images, target_len)
        if self.config.get("replace_flag", False):
            src_bg_path = self.config.get("src_bg_path")
            src_mask_path = self.config.get("src_mask_path")
            self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
            self.bg_images = self.inputs_padding(self.bg_images, target_len)
            self.mask_images = self.inputs_padding(self.mask_images, target_len)
    def get_video_segment_num(self):
        total_frames = len(self.cond_images)
        self.move_frames = self.config.target_video_length - self.config.refert_num
        if total_frames <= self.config.target_video_length:
            self.video_segment_num = 1
        else:
            self.video_segment_num = 1 + (total_frames - self.config.target_video_length + self.move_frames - 1) // self.move_frames
    def init_run(self):
        self.all_out_frames = []
        self.prepare_input()
        super().init_run()
    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images
    def init_run_segment(self, segment_idx):
        start = segment_idx * self.move_frames
        end = start + self.config.target_video_length
        if start == 0:
            self.mask_reft_len = 0
        else:
            self.mask_reft_len = self.config.refert_num
        conditioning_pixel_values = torch.tensor(
            np.stack(self.cond_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(3, 0, 1, 2)  # c t h w
        face_pixel_values = torch.tensor(
            np.stack(self.face_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(0, 3, 1, 2)  # thwc->tchw
        if start == 0:
            height, width = self.refer_images.shape[:2]
            refer_t_pixel_values = torch.zeros(
                3,
                self.config.refert_num,
                height,
                width,
                device="cuda",
                dtype=GET_DTYPE(),
            )  # c t h w
        else:
            refer_t_pixel_values = self.gen_video[0, :, -self.config.refert_num :].transpose(0, 1).clone().detach()  # c t h w
        bg_pixel_values, mask_pixel_values = None, None
        if self.config.replace_flag:
            bg_pixel_values = torch.tensor(
                np.stack(self.bg_images[start:end]) / 127.5 - 1,
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w
            mask_pixel_values = torch.tensor(
                np.stack(self.mask_images[start:end])[:, :, :, None],
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w
        self.inputs.update(
            self.run_image_encoders(
                conditioning_pixel_values,
                refer_t_pixel_values,
                bg_pixel_values,
                mask_pixel_values,
                face_pixel_values,
            )
        )
        if start != 0:
            self.model.scheduler.reset()
    def end_run_segment(self, segment_idx):
        if segment_idx != 0:
            self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
        self.all_out_frames.append(self.gen_video.cpu())
    def process_images_after_vae_decoder(self, save_video=True):
        self.gen_video = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
        del self.all_out_frames
        gc.collect()
        super().process_images_after_vae_decoder(save_video)
    def set_target_shape(self):
        self.config.target_video_length = self.config.target_video_length
        self.config.lat_h = self.refer_pixel_values.shape[-2] // 8
        self.config.lat_w = self.refer_pixel_values.shape[-1] // 8
        self.config.lat_t = self.config.target_video_length // 4 + 1
        self.config.target_shape = [16, self.config.lat_t + 1, self.config.lat_h, self.config.lat_w]
    def run_image_encoder(self, img):  # CHW
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out
    def load_transformer(self):
        model = WanAnimateModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        motion_encoder, face_encoder = self.load_encoder()
        model.set_animate_encoders(motion_encoder, face_encoder)
        return model
    def load_encoder(self):
        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE())
        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE())
        motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
        face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
        motion_encoder.load_state_dict(motion_weight_dict)
        face_encoder.load_state_dict(face_weight_dict)
        if not self.config["cpu_offload"]:
            motion_encoder = motion_encoder.cuda()
            face_encoder = face_encoder.cuda()
        return motion_encoder, face_encoder
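Editor's note: a small standalone check of the frame-padding and segmentation arithmetic used above. It mirrors inputs_padding, get_valid_len and get_video_segment_num with concrete numbers rather than calling the runner; the example frame counts are purely illustrative.

    # Illustrative re-derivation, not part of the commit.
    from copy import deepcopy

    def pingpong_pad(array, target_len):
        # same ping-pong extension as WanAnimateRunner.inputs_padding
        idx, flip, out = 0, False, []
        while len(out) < target_len:
            out.append(deepcopy(array[idx]))
            idx = idx - 1 if flip else idx + 1
            if idx == 0 or idx == len(array) - 1:
                flip = not flip
        return out[:target_len]

    def valid_len(real_len, clip_len=81, overlap=1):
        # pad real_len so (real_len - overlap) splits evenly into (clip_len - overlap) chunks
        real_clip_len = clip_len - overlap
        last = (real_len - overlap) % real_clip_len
        return real_len if last == 0 else real_len + (real_clip_len - last)

    print(pingpong_pad(["A", "B", "C"], 7))            # ['A', 'B', 'C', 'B', 'A', 'B', 'C']

    # e.g. 200 pose frames, 81-frame segments, refert_num=1 overlapping reference frame:
    target = valid_len(200, clip_len=81, overlap=1)    # (200-1) % 80 = 39 -> pad by 41 -> 241
    move = 81 - 1                                      # move_frames = target_video_length - refert_num
    segments = 1 + (target - 81 + move - 1) // move    # 3 segments, covering 81 + 2*80 = 241 frames
    print(target, segments)                            # 241 3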
lightx2v/models/runners/wan/wan_audio_runner.py
...
...
@@ -630,7 +630,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.model.scheduler.reset(self.inputs["previmg_encoder_output"])

     @ProfilingContext4DebugL1("End run segment")
-    def end_run_segment(self):
+    def end_run_segment(self, segment_idx):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
         useful_length = self.segment.end_frame - self.segment.start_frame
         video_seg = self.gen_video[:, :, :useful_length].cpu()
...
...
lightx2v/models/runners/wan/wan_runner.py
...
...
@@ -59,7 +59,7 @@ class WanRunner(DefaultRunner):
     def load_image_encoder(self):
         image_encoder = None
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             # offload config
             clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
             if clip_offload:
...
...
@@ -154,7 +154,7 @@ class WanRunner(DefaultRunner):
             "dtype": GET_DTYPE(),
             "load_from_rank0": self.config.get("load_from_rank0", False),
         }
-        if self.config.task not in ["i2v", "flf2v", "vace"]:
+        if self.config.task not in ["i2v", "flf2v", "animate", "vace"]:
             return None
         else:
             return self.vae_cls(**vae_config)
...
...
@@ -347,7 +347,7 @@ class WanRunner(DefaultRunner):
     def set_target_shape(self):
         num_channels_latents = self.config.get("num_channels_latents", 16)
-        if self.config.task in ["i2v", "flf2v"]:
+        if self.config.task in ["i2v", "flf2v", "animate"]:
             self.config.target_shape = (
                 num_channels_latents,
                 (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
...
...
lightx2v/models/schedulers/wan/scheduler.py
...
...
@@ -117,7 +117,9 @@ class WanScheduler(BaseScheduler):
             x0_pred = sample - sigma_t * model_output
         return x0_pred

-    def reset(self):
+    def reset(self, step_index=None):
+        if step_index is not None:
+            self.step_index = step_index
         self.model_outputs = [None] * self.solver_order
         self.timestep_list = [None] * self.solver_order
         self.last_sample = None
...
...
lightx2v/utils/registry_factory.py (mode changed 100644 → 100755)
...
...
@@ -50,7 +50,6 @@ RMS_WEIGHT_REGISTER = Register()
 LN_WEIGHT_REGISTER = Register()
 CONV3D_WEIGHT_REGISTER = Register()
 CONV2D_WEIGHT_REGISTER = Register()
 TENSOR_REGISTER = Register()

 RUNNER_REGISTER = Register()
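Editor's note: RUNNER_REGISTER is the registry that the new @RUNNER_REGISTER("wan2.2_animate") decorator in wan_animate_runner.py writes into. The Register class itself sits earlier in registry_factory.py and is not shown in this hunk; as a hedged mental model only, a name-keyed decorator registry of this general kind can be sketched as below. This is not the project's actual implementation.

    # Generic decorator-registry pattern; lightx2v's Register may differ in detail.
    class Register(dict):
        def __call__(self, name):
            def _decorator(cls):
                self[name] = cls        # e.g. "wan2.2_animate" -> WanAnimateRunner
                return cls
            return _decorator

    RUNNER_REGISTER = Register()

    @RUNNER_REGISTER("wan2.2_animate")
    class DummyRunner:                  # stand-in for the real runner class
        pass

    runner_cls = RUNNER_REGISTER["wan2.2_animate"]   # presumably looked up from --model_cls at inference time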
lightx2v/utils/utils.py
...
...
@@ -292,6 +292,13 @@ def save_to_video(
         raise ValueError(f"Unknown save method: {method}")


+def remove_substrings_from_keys(original_dict, substr):
+    new_dict = {}
+    for key, value in original_dict.items():
+        new_dict[key.replace(substr, "")] = value
+    return new_dict
+
+
 def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
...
...
@@ -357,50 +364,72 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def load_safetensors(in_path, remove_key):
+def load_safetensors(in_path, remove_key=None, include_keys=None):
+    """Load a safetensors file or directory, with optional filtering by included or excluded keys."""
+    include_keys = include_keys or []
     if os.path.isdir(in_path):
-        return load_safetensors_from_dir(in_path, remove_key)
+        return load_safetensors_from_dir(in_path, remove_key, include_keys)
     elif os.path.isfile(in_path):
-        return load_safetensors_from_path(in_path, remove_key)
+        return load_safetensors_from_path(in_path, remove_key, include_keys)
     else:
         raise ValueError(f"{in_path} does not exist")


-def load_safetensors_from_path(in_path, remove_key):
+def load_safetensors_from_path(in_path, remove_key=None, include_keys=None):
+    """Load weights from a single safetensors file, with optional key filtering."""
+    include_keys = include_keys or []
     tensors = {}
     with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
         for key in f.keys():
-            if remove_key not in key:
-                tensors[key] = f.get_tensor(key)
+            # include_keys takes priority: if non-empty, keep only entries containing any of the given keys
+            if include_keys:
+                if any(inc_key in key for inc_key in include_keys):
+                    tensors[key] = f.get_tensor(key)
+            # otherwise exclude entries via remove_key
+            else:
+                if not (remove_key and remove_key in key):
+                    tensors[key] = f.get_tensor(key)
     return tensors


-def load_safetensors_from_dir(in_dir, remove_key):
+def load_safetensors_from_dir(in_dir, remove_key=None, include_keys=None):
+    """Load all safetensors files in a directory, with optional key filtering."""
+    include_keys = include_keys or []
     tensors = {}
-    safetensors = os.listdir(in_dir)
-    safetensors = [f for f in safetensors if f.endswith(".safetensors")]
-    for f in safetensors:
-        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f), remove_key))
+    safetensors_files = os.listdir(in_dir)
+    safetensors_files = [f for f in safetensors_files if f.endswith(".safetensors")]
+    for f in safetensors_files:
+        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f), remove_key, include_keys))
     return tensors


-def load_pt_safetensors(in_path, remove_key):
+def load_pt_safetensors(in_path, remove_key=None, include_keys=None):
+    """Load pt/pth or safetensors weights, with optional key filtering."""
+    include_keys = include_keys or []
     ext = os.path.splitext(in_path)[-1]
     if ext in (".pt", ".pth", ".tar"):
         state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
-        for key in list(state_dict.keys()):
-            if remove_key and remove_key in key:
-                state_dict.pop(key)
+        # apply the filtering logic
+        keys_to_keep = []
+        for key in state_dict.keys():
+            if include_keys:
+                if any(inc_key in key for inc_key in include_keys):
+                    keys_to_keep.append(key)
+            else:
+                if not (remove_key and remove_key in key):
+                    keys_to_keep.append(key)
+        # keep only the matching keys
+        state_dict = {k: state_dict[k] for k in keys_to_keep}
     else:
-        state_dict = load_safetensors(in_path, remove_key)
+        state_dict = load_safetensors(in_path, remove_key, include_keys)
     return state_dict


-def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False, include_keys=None):
     if not dist.is_initialized() or not load_from_rank0:
         # Single GPU mode
         logger.info(f"Loading weights from {checkpoint_path}")
-        cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
+        cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key, include_keys)
         return cpu_weight_dict

     # Multi-GPU mode
...
...
@@ -413,9 +442,6 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_
     if is_weight_loader:
         logger.info(f"Loading weights from {checkpoint_path}")
         cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
-        for key in list(cpu_weight_dict.keys()):
-            if remove_key and remove_key in key:
-                cpu_weight_dict.pop(key)

     meta_dict = {}
     if is_weight_loader:
...
...
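Editor's note: the new include_keys path is what WanAnimateRunner.load_encoder relies on to pull only the motion/face sub-modules out of the full checkpoint. A minimal usage sketch follows; the checkpoint path is a placeholder, not a path from the repository.

    # Illustrative use of the filtered loader added in this commit.
    from lightx2v.utils.utils import load_weights, remove_substrings_from_keys

    ckpt = "/path/to/Wan2.2-Animate"                          # placeholder model_path
    motion_sd = remove_substrings_from_keys(
        load_weights(ckpt, include_keys=["motion_encoder"]),  # keep only keys containing "motion_encoder"
        "motion_encoder.",                                    # "motion_encoder.foo.weight" -> "foo.weight"
    )
    # motion_sd can then be fed to Generator(...).load_state_dict(motion_sd), as load_encoder does.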
scripts/wan22/run_wan22_animate.sh (new file, mode 100755)
#!/bin/bash

# set lightx2v_path and model_path first
lightx2v_path=
model_path=

export CUDA_VISIBLE_DEVICES=7

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
    --model_cls wan2.2_animate \
    --task animate \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_animate_replace.json \
    --prompt "视频中的人在做动作" \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate.mp4