xuwx1 / LightX2V / Commits / 04812de2

Unverified commit 04812de2, authored Sep 29, 2025 by Yang Yong (雍洋), committed via GitHub on Sep 29, 2025
Refactor Config System (#338)
parent 6a658f42

Showing 20 changed files with 329 additions and 417 deletions
lightx2v/models/networks/wan/sf_model.py (+2, -2)
lightx2v/models/networks/wan/weights/pre_weights.py (+4, -4)
lightx2v/models/networks/wan/weights/transformer_weights.py (+3, -3)
lightx2v/models/runners/base_runner.py (+2, -23)
lightx2v/models/runners/cogvideox/cogvidex_runner.py (+0, -9)
lightx2v/models/runners/default_runner.py (+69, -59)
lightx2v/models/runners/hunyuan/hunyuan_runner.py (+0, -4)
lightx2v/models/runners/qwen_image/qwen_image_runner.py (+1, -1)
lightx2v/models/runners/wan/wan_animate_runner.py (+1, -1)
lightx2v/models/runners/wan/wan_audio_runner.py (+74, -124)
lightx2v/models/runners/wan/wan_causvid_runner.py (+1, -1)
lightx2v/models/runners/wan/wan_distill_runner.py (+19, -19)
lightx2v/models/runners/wan/wan_runner.py (+58, -66)
lightx2v/models/runners/wan/wan_sf_runner.py (+32, -36)
lightx2v/models/runners/wan/wan_vace_runner.py (+4, -4)
lightx2v/models/schedulers/scheduler.py (+2, -2)
lightx2v/models/schedulers/wan/audio/scheduler.py (+18, -18)
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py (+10, -10)
lightx2v/models/schedulers/wan/scheduler.py (+19, -22)
lightx2v/models/schedulers/wan/self_forcing/scheduler.py (+10, -9)
lightx2v/models/networks/wan/sf_model.py

@@ -14,8 +14,8 @@ class WanSFModel(WanModel):
         self.to_cuda()

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        sf_confg = self.config.sf_config
-        file_path = os.path.join(self.config.sf_model_path, f"checkpoints/self_forcing_{sf_confg.sf_type}.pt")
+        sf_confg = self.config["sf_config"]
+        file_path = os.path.join(self.config["sf_model_path"], f"checkpoints/self_forcing_{sf_confg['sf_type']}.pt")
         _weight_dict = torch.load(file_path)["generator_ema"]
         weight_dict = {}
         for k, v in _weight_dict.items():
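This first hunk shows the pattern repeated throughout the commit: attribute-style reads on the config (self.config.sf_config) become dict-style subscripts (self.config["sf_config"]). A minimal sketch of the difference, assuming the config behaves like a plain mapping (the real LightX2V config object adds locking on top; see the default_runner.py hunks below):

# Sketch only: a plain dict standing in for the LightX2V config object.
config = {"sf_model_path": "/models/wan-sf", "sf_config": {"sf_type": "dmd"}}

# Dict-style access fails fast with a KeyError that names the missing key,
# whereas attribute-style access depends on how the config class implements
# __getattr__ and can fail far from the call site.
sf_config = config["sf_config"]
file_path = f"checkpoints/self_forcing_{sf_config['sf_type']}.pt"
print(file_path)  # checkpoints/self_forcing_dmd.pt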
lightx2v/models/networks/wan/weights/pre_weights.py

@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
                 MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
             )
-        if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
+        if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
@@ -58,7 +58,7 @@ class WanPreWeights(WeightModule):
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
             )
-        if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
+        if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
             self.add_module(
                 "cfg_cond_proj_1",
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),
@@ -68,12 +68,12 @@ class WanPreWeights(WeightModule):
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
             )
-        if config.task == "flf2v" and config.get("use_image_encoder", True):
+        if config["task"] == "flf2v" and config.get("use_image_encoder", True):
             self.add_module(
                 "emb_pos",
                 TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
             )
-        if config.task == "animate":
+        if config["task"] == "animate":
             self.add_module(
                 "pose_patch_embedding",
                 CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -60,7 +60,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
-            lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
+            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
             self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         else:
             self.lazy_load_file = None
@@ -197,7 +197,7 @@ class WanSelfAttention(WeightModule):
         self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.config["seq_parallel"]:
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
+            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
@@ -296,7 +296,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             self.add_module(
                 "cross_attn_k_img",
                 MM_WEIGHT_REGISTER[self.mm_type](
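The lazy-load branch above keeps one block_{i}.safetensors handle open per transformer block so individual tensors can be read on demand. A standalone sketch of that pattern with the safetensors API (the checkpoint directory and tensor name below are hypothetical):

import os

from safetensors import safe_open


def open_block_file(ckpt_dir, block_index):
    # One safetensors file per transformer block; safe_open memory-maps the
    # file so get_tensor() loads a single tensor instead of the whole block.
    path = os.path.join(ckpt_dir, f"block_{block_index}.safetensors")
    return safe_open(path, framework="pt", device="cpu")


# Usage sketch (hypothetical paths/names):
# f = open_block_file("/models/wan/dit_quantized", 0)
# q_weight = f.get_tensor("self_attn.q.weight")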
lightx2v/models/runners/base_runner.py

@@ -3,8 +3,6 @@ from abc import ABC
 import torch
 import torch.distributed as dist

-from lightx2v.utils.utils import save_videos_grid
-

 class BaseRunner(ABC):
     """Abstract base class for all Runners
@@ -15,6 +13,7 @@ class BaseRunner(ABC):
     def __init__(self, config):
         self.config = config
         self.vae_encoder_need_img_original = False
+        self.input_info = None

     def load_transformer(self):
         """Load transformer model
@@ -100,26 +99,6 @@ class BaseRunner(ABC):
         """Initialize scheduler"""
         pass

-    def set_target_shape(self):
-        """Set target shape
-
-        Subclasses can override this method to provide specific implementation
-
-        Returns:
-            Dictionary containing target shape information
-        """
-        return {}
-
-    def save_video_func(self, images):
-        """Save video implementation
-
-        Subclasses can override this method to customize save logic
-
-        Args:
-            images: Image sequence to save
-        """
-        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
-
     def load_vae_decoder(self):
         """Load VAE decoder
@@ -146,7 +125,7 @@ class BaseRunner(ABC):
         pass

     def end_run_segment(self, segment_idx=None):
-        pass
+        self.gen_video_final = self.gen_video

     def end_run(self):
         pass
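BaseRunner.__init__ now carries an input_info object, and DefaultRunner.set_inputs (in default_runner.py below) probes it with __dataclass_fields__ before touching optional fields. The class itself is not visible in these hunks; a minimal sketch of what it could look like, with fields inferred from how the diff uses it and defaults that are assumptions:

from dataclasses import dataclass, field


@dataclass
class T2VInputInfo:  # hypothetical name; fields inferred from the diff
    seed: int = 42
    prompt: str = ""
    prompt_enhanced: str = ""
    negative_prompt: str = ""
    save_result_path: str = ""
    return_result_tensor: bool = False
    latent_shape: list = field(default_factory=list)


# set_inputs can feature-detect optional fields instead of isinstance checks:
info = T2VInputInfo()
if "image_path" in info.__dataclass_fields__:  # False for this t2v variant
    info.image_path = "input.png"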
lightx2v/models/runners/cogvideox/cogvidex_runner.py

-import imageio
-import numpy as np
-
 from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
 from lightx2v.models.networks.cogvideox.model import CogvideoxModel
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -72,9 +69,3 @@ class CogvideoxRunner(DefaultRunner):
         )
         ret["target_shape"] = self.config.target_shape
         return ret
-
-    def save_video_func(self, images):
-        with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
-            for pil_image in images:
-                frame_np = np.array(pil_image, dtype=np.uint8)
-                writer.append_data(frame_np)
lightx2v/models/runners/default_runner.py

@@ -22,13 +22,13 @@ class DefaultRunner(BaseRunner):
         super().__init__(config)
         self.has_prompt_enhancer = False
         self.progress_callback = None
-        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
             self.has_prompt_enhancer = True
             if not self.check_sub_servers("prompt_enhancer"):
                 self.has_prompt_enhancer = False
                 logger.warning("No prompt enhancer server available, disable prompt enhancer.")
         if not self.has_prompt_enhancer:
-            self.config.use_prompt_enhancer = False
+            self.config["use_prompt_enhancer"] = False
         self.set_init_device()
         self.init_scheduler()
@@ -49,12 +49,15 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_vace
         elif self.config["task"] == "animate":
             self.run_input_encoder = self._run_input_encoder_local_animate
+        elif self.config["task"] == "s2v":
+            self.run_input_encoder = self._run_input_encoder_local_s2v
+        self.config.lock()  # lock config to avoid modification
         if self.config.get("compile", False):
             logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
             self.model.compile(self.config.get("compile_shapes", []))

     def set_init_device(self):
-        if self.config.cpu_offload:
+        if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
             self.init_device = torch.device("cuda")
@@ -96,21 +99,23 @@ class DefaultRunner(BaseRunner):
         return len(available_servers) > 0

     def set_inputs(self, inputs):
-        self.config["prompt"] = inputs.get("prompt", "")
-        self.config["use_prompt_enhancer"] = False
-        if self.has_prompt_enhancer:
-            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
-        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
-        self.config["image_path"] = inputs.get("image_path", "")
-        self.config["save_video_path"] = inputs.get("save_video_path", "")
-        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
-        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
-        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
-        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
-        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio
-        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
-        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
+        self.input_info.seed = inputs.get("seed", 42)
+        self.input_info.prompt = inputs.get("prompt", "")
+        if self.config["use_prompt_enhancer"]:
+            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
+        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
+        if "image_path" in self.input_info.__dataclass_fields__:
+            self.input_info.image_path = inputs.get("image_path", "")
+        if "audio_path" in self.input_info.__dataclass_fields__:
+            self.input_info.audio_path = inputs.get("audio_path", "")
+        if "video_path" in self.input_info.__dataclass_fields__:
+            self.input_info.video_path = inputs.get("video_path", "")
+        self.input_info.save_result_path = inputs.get("save_result_path", "")
+
+    def set_config(self, config_modify):
+        logger.info(f"modify config: {config_modify}")
+        with self.config.temporarily_unlocked():
+            self.config.update(config_modify)

     def set_progress_callback(self, callback):
         self.progress_callback = callback
@@ -146,6 +151,7 @@ class DefaultRunner(BaseRunner):
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs
+        self.input_info = None
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                 self.model.transformer_infer.weights_stream_mgr.clear()
@@ -162,23 +168,24 @@ class DefaultRunner(BaseRunner):
         else:
             img_ori = Image.open(img_path).convert("RGB")
         img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        self.input_info.original_size = img_ori.size
         return img, img_ori

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img, img_ori = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.input_info.image_path)
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
-        text_encoder_output = self.run_text_encoder(prompt, img)
+        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return {
@@ -188,22 +195,21 @@ class DefaultRunner(BaseRunner):
     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        first_frame, _ = self.read_image_input(self.config["image_path"])
-        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
+        first_frame, _ = self.read_image_input(self.input_info.image_path)
+        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
         clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
-        text_encoder_output = self.run_text_encoder(prompt, first_frame)
+        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_vace(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        src_video = self.config.get("src_video", None)
-        src_mask = self.config.get("src_mask", None)
-        src_ref_images = self.config.get("src_ref_images", None)
+        src_video = self.input_info.src_video
+        src_mask = self.input_info.src_mask
+        src_ref_images = self.input_info.src_ref_images
         src_video, src_mask, src_ref_images = self.prepare_source(
             [src_video],
             [src_mask],
@@ -212,34 +218,38 @@ class DefaultRunner(BaseRunner):
         )
         self.src_ref_images = src_ref_images
-        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
-        text_encoder_output = self.run_text_encoder(prompt)
+        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Text Encoder")
     def _run_input_encoder_local_animate(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

+    def _run_input_encoder_local_s2v(self):
+        pass
+
     def init_run(self):
-        self.set_target_shape()
+        self.gen_video_final = None
         self.get_video_segment_num()
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
+        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
+        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

     @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
         if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
+            self.model.select_graph_for_compile(self.input_info)
         for segment_idx in range(self.video_segment_num):
             logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
             with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
@@ -252,7 +262,9 @@ class DefaultRunner(BaseRunner):
                 self.gen_video = self.run_vae_decoder(latents)
                 # 4. default do nothing
                 self.end_run_segment(segment_idx)
+        gen_video_final = self.process_images_after_vae_decoder()
         self.end_run()
+        return {"video": gen_video_final}

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
@@ -281,20 +293,22 @@ class DefaultRunner(BaseRunner):
         logger.info(f"Enhanced prompt: {enhanced_prompt}")
         return enhanced_prompt

-    def process_images_after_vae_decoder(self, save_video=True):
-        self.gen_video = vae_to_comfyui_image(self.gen_video)
+    def process_images_after_vae_decoder(self):
+        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)
         if "video_frame_interpolation" in self.config:
             assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
             target_fps = self.config["video_frame_interpolation"]["target_fps"]
             logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
-            self.gen_video = self.vfi_model.interpolate_frames(
-                self.gen_video,
+            self.gen_video_final = self.vfi_model.interpolate_frames(
+                self.gen_video_final,
                 source_fps=self.config.get("fps", 16),
                 target_fps=target_fps,
             )
-        if save_video:
+        if self.input_info.return_result_tensor:
+            return {"video": self.gen_video_final}
+        elif self.input_info.save_result_path is not None:
             if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                 fps = self.config["video_frame_interpolation"]["target_fps"]
             else:
@@ -303,22 +317,18 @@ class DefaultRunner(BaseRunner):
             if not dist.is_initialized() or dist.get_rank() == 0:
                 logger.info(f"🎬 Start to save video 🎬")
-                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
-                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
-            if self.config.get("return_video", False):
-                return {"video": self.gen_video}
-            return {"video": None}
+                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
+                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
+        return {"video": None}

-    def run_pipeline(self, save_video=True):
+    def run_pipeline(self, input_info):
+        self.input_info = input_info
         if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.input_info.prompt_enhanced = self.post_prompt_enhancer()
         self.inputs = self.run_input_encoder()
-        self.run_main()
-        gen_video = self.process_images_after_vae_decoder(save_video=save_video)
+        gen_video_final = self.run_main()
         torch.cuda.empty_cache()
         gc.collect()
-        return gen_video
+        return gen_video_final
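Two of the hunks above define the refactor's guard rail: self.config.lock() freezes the config once the runner is wired up, and set_config reopens it only inside with self.config.temporarily_unlocked(). A minimal sketch of a dict subclass with that interface, assuming plain-dict semantics underneath (the actual class may differ):

from contextlib import contextmanager


class LockableConfig(dict):
    # Sketch only, not the LightX2V implementation.
    _locked = False

    def lock(self):
        self._locked = True

    @contextmanager
    def temporarily_unlocked(self):
        prev, self._locked = self._locked, False
        try:
            yield self
        finally:
            self._locked = prev

    def __setitem__(self, key, value):
        if self._locked:
            raise RuntimeError(f"config is locked; cannot set {key!r}")
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Route update() through __setitem__ so it honors the lock.
        for k, v in dict(*args, **kwargs).items():
            self[k] = v


config = LockableConfig(task="t2v")
config.lock()
with config.temporarily_unlocked():
    config.update({"infer_steps": 4})  # allowed inside the context
# config["infer_steps"] = 8  # would raise RuntimeError out here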
lightx2v/models/runners/hunyuan/hunyuan_runner.py

@@ -14,7 +14,6 @@ from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_videos_grid


 @RUNNER_REGISTER("hunyuan")
@@ -152,6 +151,3 @@ class HunyuanRunner(DefaultRunner):
             int(self.config.target_width) // vae_scale_factor,
         )
         return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
-
-    def save_video_func(self, images):
-        save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -204,7 +204,7 @@ class QwenImageRunner(DefaultRunner):
         images = self.run_vae_decoder(latents, generator)
         image = images[0]
-        image.save(f"{self.config.save_video_path}")
+        image.save(f"{self.config.save_result_path}")
         del latents, generator
         torch.cuda.empty_cache()
lightx2v/models/runners/wan/wan_animate_runner.py

@@ -334,7 +334,7 @@ class WanAnimateRunner(WanRunner):
         )

         if start != 0:
-            self.model.scheduler.reset()
+            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)

     def end_run_segment(self, segment_idx):
         if segment_idx != 0:
lightx2v/models/runners/wan/wan_audio_runner.py

 import gc
+import json
 import os
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import numpy as np
@@ -337,16 +337,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         """Initialize consistency model scheduler"""
         self.scheduler = EulerScheduler(self.config)

-    def read_audio_input(self):
+    def read_audio_input(self, audio_path):
         """Read audio input - handles both single and multi-person scenarios"""
         audio_sr = self.config.get("audio_sr", 16000)
         target_fps = self.config.get("target_fps", 16)
         self._audio_processor = AudioProcessor(audio_sr, target_fps)

-        # Get audio files from person objects or legacy format
-        audio_files = self._get_audio_files_from_config()
-        if not audio_files:
-            return [], 0
+        audio_files, mask_files = self.get_audio_files_from_audio_path(audio_path)

         # Load audio based on single or multi-person mode
         if len(audio_files) == 1:
@@ -355,8 +353,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             audio_array = self._audio_processor.load_multi_person_audio(audio_files)

-        self.config.audio_num = audio_array.size(0)
-
         video_duration = self.config.get("video_duration", 5)
         audio_len = int(audio_array.shape[1] / audio_sr * target_fps)
         expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
@@ -364,60 +360,35 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # Segment audio
         audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
-        return audio_array.size(0), audio_segments, expected_frames
-
-    def _get_audio_files_from_config(self):
-        talk_objects = self.config.get("talk_objects")
-        if talk_objects:
-            audio_files = []
-            for idx, person in enumerate(talk_objects):
-                audio_path = person.get("audio")
-                if audio_path and Path(audio_path).is_file():
-                    audio_files.append(str(audio_path))
-                else:
-                    logger.warning(f"Person {idx} audio file {audio_path} does not exist or not specified")
-            if audio_files:
-                logger.info(f"Loaded {len(audio_files)} audio files from talk_objects")
-                return audio_files
-        audio_path = self.config.get("audio_path")
-        if audio_path:
-            return [audio_path]
-        logger.error("config audio_path or talk_objects is not specified")
-        return []
-
-    def read_person_mask(self):
-        mask_files = self._get_mask_files_from_config()
-        if not mask_files:
-            return None
-        mask_latents = []
-        for mask_file in mask_files:
-            mask_latent = self._process_single_mask(mask_file)
-            mask_latents.append(mask_latent)
-        mask_latents = torch.cat(mask_latents, dim=0)
-        return mask_latents
-
-    def _get_mask_files_from_config(self):
-        talk_objects = self.config.get("talk_objects")
-        if talk_objects:
-            mask_files = []
-            for idx, person in enumerate(talk_objects):
-                mask_path = person.get("mask")
-                if mask_path and Path(mask_path).is_file():
-                    mask_files.append(str(mask_path))
-                elif mask_path:
-                    logger.warning(f"Person {idx} mask file {mask_path} does not exist")
-            if mask_files:
-                logger.info(f"Loaded {len(mask_files)} mask files from talk_objects")
-                return mask_files
-        logger.info("config talk_objects is not specified")
-        return None
-
-    def _process_single_mask(self, mask_file):
+        # Mask latent for multi-person s2v
+        if mask_files is not None:
+            mask_latents = [self.process_single_mask(mask_file) for mask_file in mask_files]
+            mask_latents = torch.cat(mask_latents, dim=0)
+        else:
+            mask_latents = None
+
+        return audio_segments, expected_frames, mask_latents, len(audio_files)
+
+    def get_audio_files_from_audio_path(self, audio_path):
+        if os.path.isdir(audio_path):
+            audio_files = []
+            mask_files = []
+            logger.info(f"audio_path is a directory, loading config.json from {audio_path}")
+            audio_config_path = os.path.join(audio_path, "config.json")
+            assert os.path.exists(audio_config_path), "config.json not found in audio_path"
+            with open(audio_config_path, "r") as f:
+                audio_config = json.load(f)
+            for talk_object in audio_config["talk_objects"]:
+                audio_files.append(os.path.join(audio_path, talk_object["audio"]))
+                mask_files.append(os.path.join(audio_path, talk_object["mask"]))
+        else:
+            logger.info(f"audio_path is a file without mask: {audio_path}")
+            audio_files = [audio_path]
+            mask_files = None
+        return audio_files, mask_files
+
+    def process_single_mask(self, mask_file):
         mask_img = Image.open(mask_file).convert("RGB")
         mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
@@ -456,21 +427,21 @@ class WanAudioRunner(WanRunner):  # type:ignore
             fixed_shape=self.config.get("fixed_shape", None),
         )
         logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}")
-        patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
-        patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
+        patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
+        patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]
         patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
-        self.config.lat_h = patched_h * self.config.patch_size[1]
-        self.config.lat_w = patched_w * self.config.patch_size[2]
-        self.config.tgt_h = self.config.lat_h * self.config.vae_stride[1]
-        self.config.tgt_w = self.config.lat_w * self.config.vae_stride[2]
-        logger.info(f"[wan_audio] tgt_h: {self.config.tgt_h}, tgt_w: {self.config.tgt_w}, lat_h: {self.config.lat_h}, lat_w: {self.config.lat_w}")
-        ref_img = torch.nn.functional.interpolate(ref_img, size=(self.config.tgt_h, self.config.tgt_w), mode="bicubic")
-        return ref_img
+        latent_h = patched_h * self.config["patch_size"][1]
+        latent_w = patched_w * self.config["patch_size"][2]
+        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
+        target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]
+        logger.info(f"[wan_audio] target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}")
+        ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
+        return ref_img, latent_shape, target_shape

     def run_image_encoder(self, first_frame, last_frame=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
@@ -496,20 +467,17 @@ class WanAudioRunner(WanRunner):  # type:ignore
         return vae_encoder_out

     @ProfilingContext4DebugL2("Run Encoders")
-    def _run_input_encoder_local_r2v_audio(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = self.read_image_input(self.config["image_path"])
+    def _run_input_encoder_local_s2v(self):
+        img, latent_shape, target_shape = self.read_image_input(self.input_info.image_path)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        self.input_info.target_shape = target_shape  # Important: set target_shape in input_info
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
         vae_encode_out = self.run_vae_encoder(img)
-        audio_num, audio_segments, expected_frames = self.read_audio_input()
-        person_mask_latens = self.read_person_mask()
-        self.config.person_num = 0
-        if person_mask_latens is not None:
-            assert audio_num == person_mask_latens.size(0), "audio_num and person_mask_latens.size(0) must be the same"
-            self.config.person_num = person_mask_latens.size(0)
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        audio_segments, expected_frames, person_mask_latens, audio_num = self.read_audio_input(self.input_info.audio_path)
+        self.input_info.audio_num = audio_num
+        self.input_info.with_mask = person_mask_latens is not None
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return {
@@ -528,13 +496,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
         device = torch.device("cuda")
         dtype = GET_DTYPE()
-        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
-        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
+        tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
+        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)

         if prev_video is not None:
             # Extract and process last frames
             last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
-            if self.config.model_cls != "wan2.2_audio":
+            if self.config["model_cls"] != "wan2.2_audio":
                 last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
             prev_frames[:, :, :prev_frame_length] = last_frames

         prev_len = (prev_frame_length - 1) // 4 + 1
@@ -546,7 +514,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         _, nframe, height, width = self.model.scheduler.latents.shape
         with ProfilingContext4DebugL1("vae_encoder in init run segment"):
-            if self.config.model_cls == "wan2.2_audio":
+            if self.config["model_cls"] == "wan2.2_audio":
                 if prev_video is not None:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
                 else:
@@ -563,7 +531,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if prev_latents is not None:
             if prev_latents.shape[-2:] != (height, width):
-                logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
+                logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={tgt_h}, tgt_w={tgt_w}")
                 prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
@@ -592,8 +560,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         super().init_run()
         self.scheduler.set_audio_adapter(self.audio_adapter)
         self.prev_video = None
-        if self.config.get("return_video", False):
-            self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.config.tgt_h, self.config.tgt_w, 3), dtype=torch.float32, device="cpu")
+        if self.input_info.return_result_tensor:
+            self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.input_info.target_shape[0], self.input_info.target_shape[1], 3), dtype=torch.float32, device="cpu")
             self.cut_audio_final = torch.zeros((self.inputs["expected_frames"] * self._audio_processor.audio_frame_rate), dtype=torch.float32, device="cpu")
         else:
             self.gen_video_final = None
@@ -608,8 +576,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             self.segment = self.inputs["audio_segments"][segment_idx]

-        self.config.seed = self.config.seed + segment_idx
-        torch.manual_seed(self.config.seed)
+        self.input_info.seed = self.input_info.seed + segment_idx
+        torch.manual_seed(self.input_info.seed)
         # logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
         if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
@@ -627,7 +595,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # Reset scheduler for non-first segments
         if segment_idx > 0:
-            self.model.scheduler.reset(self.inputs["previmg_encoder_output"])
+            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape, self.inputs["previmg_encoder_output"])

     @ProfilingContext4DebugL1("End run segment")
     def end_run_segment(self, segment_idx):
@@ -650,7 +618,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             if self.va_recorder:
                 self.va_recorder.pub_livestream(video_seg, audio_seg)
-        elif self.config.get("return_video", False):
+        elif self.input_info.return_result_tensor:
             self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
             self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
@@ -669,7 +637,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         return rank, world_size

     def init_va_recorder(self):
-        output_video_path = self.config.get("save_video_path", None)
+        output_video_path = self.input_info.save_result_path
         self.va_recorder = None
         if isinstance(output_video_path, dict):
             output_video_path = output_video_path["data"]
@@ -722,7 +690,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.init_run()
         if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
+            self.model.select_graph_for_compile(self.input_info)
         self.video_segment_num = "unlimited"
         fetch_timeout = self.va_reader.segment_duration + 1
@@ -760,24 +728,20 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.va_recorder = None

     @ProfilingContext4DebugL1("Process after vae decoder")
-    def process_images_after_vae_decoder(self, save_video=False):
-        if self.config.get("return_video", False):
+    def process_images_after_vae_decoder(self):
+        if self.input_info.return_result_tensor:
             audio_waveform = self.cut_audio_final.unsqueeze(0).unsqueeze(0)
             comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
             return {"video": self.gen_video_final, "audio": comfyui_audio}
         return {"video": None, "audio": None}

     def init_modules(self):
         super().init_modules()
-        self.run_input_encoder = self._run_input_encoder_local_r2v_audio

     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+        base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
@@ -814,7 +778,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter.to(device)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
+        weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
@@ -824,28 +788,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.audio_encoder = self.load_audio_encoder()
         self.audio_adapter = self.load_audio_adapter()

-    def set_target_shape(self):
-        """Set target shape for generation"""
-        ret = {}
-        num_channels_latents = 16
-        if self.config.model_cls == "wan2.2_audio":
-            num_channels_latents = self.config.num_channels_latents
-        if self.config.task == "i2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-            ret["lat_h"] = self.config.lat_h
-            ret["lat_w"] = self.config.lat_w
-        else:
-            error_msg = "t2v task is not supported in WanAudioRunner"
-            assert False, error_msg
-        ret["target_shape"] = self.config.target_shape
-        return ret
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape


 @RUNNER_REGISTER("wan2.2_audio")
@@ -882,7 +832,7 @@ class Wan22AudioRunner(WanAudioRunner):
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
-        if self.config.task != "i2v":
+        if self.config.task not in ["i2v", "s2v"]:
             return None
         else:
             return Wan2_2_VAE(**vae_config)
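The new get_audio_files_from_audio_path accepts either a single audio file or a directory whose config.json lists talk_objects, each pairing a speaker's audio with a mask. A sketch of a layout the parsing above would accept (all file names hypothetical):

import json
import os

# Hypothetical multi-person layout:
#   my_audio_dir/
#     config.json
#     person0.wav   person0_mask.png
#     person1.wav   person1_mask.png
audio_path = "my_audio_dir"
os.makedirs(audio_path, exist_ok=True)
with open(os.path.join(audio_path, "config.json"), "w") as f:
    json.dump({"talk_objects": [
        {"audio": "person0.wav", "mask": "person0_mask.png"},
        {"audio": "person1.wav", "mask": "person1_mask.png"},
    ]}, f)

# Mirrors the parsing in get_audio_files_from_audio_path:
with open(os.path.join(audio_path, "config.json")) as f:
    cfg = json.load(f)
audio_files = [os.path.join(audio_path, t["audio"]) for t in cfg["talk_objects"]]
mask_files = [os.path.join(audio_path, t["mask"]) for t in cfg["talk_objects"]]
print(audio_files, mask_files)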
lightx2v/models/runners/wan/wan_causvid_runner.py

@@ -50,7 +50,7 @@ class WanCausVidRunner(WanRunner):
         self.scheduler = WanStepDistillScheduler(self.config)

     def set_target_shape(self):
-        if self.config.task == "i2v":
+        if self.config.task in ["i2v", "s2v"]:
             self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
             # For i2v, frame_seq_length must be reset according to the input shape
             frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
lightx2v/models/runners/wan/wan_distill_runner.py

@@ -17,28 +17,28 @@ class WanDistillRunner(WanRunner):
         super().__init__(config)

     def load_transformer(self):
-        if self.config.get("lora_configs") and self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
             model = WanModel(
-                self.config.model_path,
+                self.config["model_path"],
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
                 lora_wrapper.apply_lora(lora_name, strength)
                 logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
-            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
+            model = WanDistillModel(self.config["model_path"], self.config, self.init_device)
         return model

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             self.scheduler = WanStepDistillScheduler(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")


 class MultiDistillModelStruct(MultiModelStruct):
@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct):
     def get_current_model_index(self):
         if self.scheduler.step_index < self.boundary_step_index:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct):
                 self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
@@ -81,8 +81,8 @@ class Wan22MoeDistillRunner(WanDistillRunner):
     def load_transformer(self):
         use_high_lora, use_low_lora = False, False
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            for lora_config in self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "high_noise_model":
                     use_high_lora = True
                 elif lora_config.get("name", "") == "low_noise_model":
@@ -90,12 +90,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
         if use_high_lora:
             high_noise_model = WanModel(
-                os.path.join(self.config.model_path, "high_noise_model"),
+                os.path.join(self.config["model_path"], "high_noise_model"),
                 self.config,
                 self.init_device,
             )
             high_lora_wrapper = WanLoraWrapper(high_noise_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "high_noise_model":
                     lora_path = lora_config["path"]
                     strength = lora_config.get("strength", 1.0)
@@ -104,19 +104,19 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                     logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             high_noise_model = Wan22MoeDistillModel(
-                os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
+                os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
                 self.config,
                 self.init_device,
             )

         if use_low_lora:
             low_noise_model = WanModel(
-                os.path.join(self.config.model_path, "low_noise_model"),
+                os.path.join(self.config["model_path"], "low_noise_model"),
                 self.config,
                 self.init_device,
             )
             low_lora_wrapper = WanLoraWrapper(low_noise_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "low_noise_model":
                     lora_path = lora_config["path"]
                     strength = lora_config.get("strength", 1.0)
@@ -125,15 +125,15 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                     logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             low_noise_model = Wan22MoeDistillModel(
-                os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
+                os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
                 self.config,
                 self.init_device,
             )
-        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)
+        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             self.scheduler = Wan22StepDistillScheduler(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
lightx2v/models/runners/wan/wan_runner.py
View file @
04812de2
...
...
@@ -28,7 +28,7 @@ from lightx2v.utils.envs import *
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
*
from
lightx2v.utils.utils
import
best_output_size
,
cache_video
from
lightx2v.utils.utils
import
best_output_size
@
RUNNER_REGISTER
(
"wan2.1"
)
...
...
@@ -42,7 +42,7 @@ class WanRunner(DefaultRunner):
def
load_transformer
(
self
):
model
=
WanModel
(
self
.
config
.
model_path
,
self
.
config
[
"
model_path
"
]
,
self
.
config
,
self
.
init_device
,
)
...
...
@@ -59,7 +59,7 @@ class WanRunner(DefaultRunner):
def
load_image_encoder
(
self
):
image_encoder
=
None
if
self
.
config
.
task
in
[
"i2v"
,
"flf2v"
,
"animate"
]
and
self
.
config
.
get
(
"use_image_encoder"
,
True
):
if
self
.
config
[
"
task
"
]
in
[
"i2v"
,
"flf2v"
,
"animate"
,
"s2v"
]
and
self
.
config
.
get
(
"use_image_encoder"
,
True
):
# offload config
clip_offload
=
self
.
config
.
get
(
"clip_cpu_offload"
,
self
.
config
.
get
(
"cpu_offload"
,
False
))
if
clip_offload
:
...
...
@@ -148,13 +148,13 @@ class WanRunner(DefaultRunner):
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
self
.
vae_name
),
"device"
:
vae_device
,
"parallel"
:
self
.
config
.
parallel
,
"parallel"
:
self
.
config
[
"
parallel
"
]
,
"use_tiling"
:
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
"load_from_rank0"
:
self
.
config
.
get
(
"load_from_rank0"
,
False
),
}
if
self
.
config
.
task
not
in
[
"i2v"
,
"flf2v"
,
"animate"
,
"vace"
]:
if
self
.
config
[
"
task
"
]
not
in
[
"i2v"
,
"flf2v"
,
"animate"
,
"vace"
,
"s2v"
]:
return
None
else
:
return
self
.
vae_cls
(
**
vae_config
)
...
...
@@ -170,7 +170,7 @@ class WanRunner(DefaultRunner):
vae_config
=
{
"vae_pth"
:
find_torch_model_path
(
self
.
config
,
"vae_pth"
,
self
.
vae_name
),
"device"
:
vae_device
,
"parallel"
:
self
.
config
.
parallel
,
"parallel"
:
self
.
config
[
"
parallel
"
]
,
"use_tiling"
:
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
...
...
@@ -192,9 +192,9 @@ class WanRunner(DefaultRunner):
return
vae_encoder
,
vae_decoder
def
init_scheduler
(
self
):
if
self
.
config
.
feature_caching
==
"NoCaching"
:
if
self
.
config
[
"
feature_caching
"
]
==
"NoCaching"
:
scheduler_class
=
WanScheduler
elif
self
.
config
.
feature_caching
==
"TaylorSeer"
:
elif
self
.
config
[
"
feature_caching
"
]
==
"TaylorSeer"
:
scheduler_class
=
WanSchedulerTaylorCaching
elif
self
.
config
.
feature_caching
in
[
"Tea"
,
"Ada"
,
"Custom"
,
"FirstBlock"
,
"DualBlock"
,
"DynamicBlock"
,
"Mag"
]:
scheduler_class
=
WanSchedulerCaching
...
...
@@ -206,26 +206,28 @@ class WanRunner(DefaultRunner):
else
:
self
.
scheduler
=
scheduler_class
(
self
.
config
)
def
run_text_encoder
(
self
,
text
,
img
=
None
):
def
run_text_encoder
(
self
,
input_info
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
text_encoders
=
self
.
load_text_encoder
()
n_prompt
=
self
.
config
.
get
(
"negative_prompt"
,
""
)
prompt
=
input_info
.
prompt_enhanced
if
self
.
config
[
"use_prompt_enhancer"
]
else
input_info
.
prompt
neg_prompt
=
input_info
.
negative_prompt
if
self
.
config
[
"cfg_parallel"
]:
cfg_p_group
=
self
.
config
[
"device_mesh"
].
get_group
(
mesh_dim
=
"cfg_p"
)
cfg_p_rank
=
dist
.
get_rank
(
cfg_p_group
)
if
cfg_p_rank
==
0
:
context
=
self
.
text_encoders
[
0
].
infer
([
tex
t
])
context
=
self
.
text_encoders
[
0
].
infer
([
promp
t
])
context
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
self
.
config
[
"text_len"
]
-
u
.
size
(
0
),
u
.
size
(
1
))])
for
u
in
context
])
text_encoder_output
=
{
"context"
:
context
}
else
:
context_null
=
self
.
text_encoders
[
0
].
infer
([
n_prompt
])
context_null
=
self
.
text_encoders
[
0
].
infer
([
n
eg
_prompt
])
context_null
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
self
.
config
[
"text_len"
]
-
u
.
size
(
0
),
u
.
size
(
1
))])
for
u
in
context_null
])
text_encoder_output
=
{
"context_null"
:
context_null
}
else
:
context
=
self
.
text_encoders
[
0
].
infer
([
tex
t
])
context
=
self
.
text_encoders
[
0
].
infer
([
promp
t
])
context
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
self
.
config
[
"text_len"
]
-
u
.
size
(
0
),
u
.
size
(
1
))])
for
u
in
context
])
context_null
=
self
.
text_encoders
[
0
].
infer
([
n_prompt
])
context_null
=
self
.
text_encoders
[
0
].
infer
([
n
eg
_prompt
])
context_null
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
self
.
config
[
"text_len"
]
-
u
.
size
(
0
),
u
.
size
(
1
))])
for
u
in
context_null
])
text_encoder_output
=
{
"context"
:
context
,
...
...
@@ -255,22 +257,22 @@ class WanRunner(DefaultRunner):
def
run_vae_encoder
(
self
,
first_frame
,
last_frame
=
None
):
h
,
w
=
first_frame
.
shape
[
2
:]
aspect_ratio
=
h
/
w
max_area
=
self
.
config
.
target_height
*
self
.
config
.
target_width
lat_h
=
round
(
np
.
sqrt
(
max_area
*
aspect_ratio
)
//
self
.
config
.
vae_stride
[
1
]
//
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
1
])
lat_w
=
round
(
np
.
sqrt
(
max_area
/
aspect_ratio
)
//
self
.
config
.
vae_stride
[
2
]
//
self
.
config
.
patch_size
[
2
]
*
self
.
config
.
patch_size
[
2
])
max_area
=
self
.
config
[
"target_height"
]
*
self
.
config
[
"target_width"
]
latent_h
=
round
(
np
.
sqrt
(
max_area
*
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
1
]
//
self
.
config
[
"patch_size"
][
1
]
*
self
.
config
[
"patch_size"
][
1
])
latent_w
=
round
(
np
.
sqrt
(
max_area
/
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
2
]
//
self
.
config
[
"patch_size"
][
2
]
*
self
.
config
[
"patch_size"
][
2
])
latent_shape
=
self
.
get_latent_shape_with_lat_hw
(
latent_h
,
latent_w
)
# Important: latent_shape is used to set the input_info
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
assert
last_frame
is
None
self
.
config
.
lat_h
,
self
.
config
.
lat_w
=
lat_h
,
lat_w
vae_encode_out_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
lat
_h
,
lat_w
=
(
int
(
self
.
config
.
la
t_h
*
self
.
config
.
resolution_rate
[
i
])
//
2
*
2
,
int
(
self
.
config
.
la
t_w
*
self
.
config
.
resolution_rate
[
i
])
//
2
*
2
,
lat
ent_h_tmp
,
latent_w_tmp
=
(
int
(
laten
t_h
*
self
.
config
[
"
resolution_rate
"
]
[
i
])
//
2
*
2
,
int
(
laten
t_w
*
self
.
config
[
"
resolution_rate
"
]
[
i
])
//
2
*
2
,
)
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
lat
_h
,
lat_w
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
self
.
config
.
lat_h
,
self
.
config
.
la
t_w
))
return
vae_encode_out_list
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
lat
ent_h_tmp
,
latent_w_tmp
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
laten
t_w
))
return
vae_encode_out_list
,
latent_shape
else
:
if
last_frame
is
not
None
:
first_frame_size
=
first_frame
.
shape
[
2
:]
...
...
@@ -282,16 +284,15 @@ class WanRunner(DefaultRunner):
round
(
last_frame_size
[
1
]
*
last_frame_resize_ratio
),
]
last_frame
=
TF
.
center_crop
(
last_frame
,
last_frame_size
)
self
.
config
.
lat_h
,
self
.
config
.
lat_w
=
lat_h
,
lat_w
vae_encoder_out
=
self
.
get_vae_encoder_output
(
first_frame
,
lat_h
,
lat_w
,
last_frame
)
return
vae_encoder_out
vae_encoder_out
=
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
latent_w
,
last_frame
)
return
vae_encoder_out
,
latent_shape
def
get_vae_encoder_output
(
self
,
first_frame
,
lat_h
,
lat_w
,
last_frame
=
None
):
h
=
lat_h
*
self
.
config
.
vae_stride
[
1
]
w
=
lat_w
*
self
.
config
.
vae_stride
[
2
]
h
=
lat_h
*
self
.
config
[
"
vae_stride
"
]
[
1
]
w
=
lat_w
*
self
.
config
[
"
vae_stride
"
]
[
2
]
msk
=
torch
.
ones
(
1
,
self
.
config
.
target_video_length
,
self
.
config
[
"
target_video_length
"
]
,
lat_h
,
lat_w
,
device
=
torch
.
device
(
"cuda"
),
...
...
@@ -312,7 +313,7 @@ class WanRunner(DefaultRunner):
             vae_input = torch.concat(
                 [
                     torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                    torch.zeros(3, self.config.target_video_length - 2, h, w),
+                    torch.zeros(3, self.config["target_video_length"] - 2, h, w),
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
...
...
@@ -321,7 +322,7 @@ class WanRunner(DefaultRunner):
             vae_input = torch.concat(
                 [
                     torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                    torch.zeros(3, self.config.target_video_length - 1, h, w),
+                    torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
             ).cuda()
...
...
@@ -345,32 +346,23 @@ class WanRunner(DefaultRunner):
             "image_encoder_output": image_encoder_output,
         }
 
-    def set_target_shape(self):
-        num_channels_latents = self.config.get("num_channels_latents", 16)
-        if self.config.task in ["i2v", "flf2v", "animate"]:
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-        elif self.config.task == "t2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                int(self.config.target_height) // self.config.vae_stride[1],
-                int(self.config.target_width) // self.config.vae_stride[2],
-            )
-
     def save_video_func(self, images):
         cache_video(
             tensor=images,
             save_file=self.config.save_video_path,
             fps=self.config.get("fps", 16),
             nrow=1,
             normalize=True,
             value_range=(-1, 1),
         )
 
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape
+
+    def get_latent_shape_with_target_hw(self, target_h, target_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            int(target_h) // self.config["vae_stride"][1],
+            int(target_w) // self.config["vae_stride"][2],
+        ]
+        return latent_shape
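Note: the two new helpers replace the removed `set_target_shape` and differ only in whether the caller already holds latent coordinates or still holds pixel coordinates. The temporal term is shared; with illustrative values (`target_video_length=81`, temporal `vae_stride[0]=4` assumed):

    target_video_length = 81   # illustrative config value
    vae_stride_t = 4           # assumed temporal VAE stride
    num_latent_frames = (target_video_length - 1) // vae_stride_t + 1
    print(num_latent_frames)   # 21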
class MultiModelStruct:
...
...
@@ -400,7 +392,7 @@ class MultiModelStruct:
     def get_current_model_index(self):
         if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
...
...
@@ -410,7 +402,7 @@ class MultiModelStruct:
                 self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
...
...
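Note: `self.config["sample_guide_scale"]` is now a two-element list indexed by the active expert: entry 0 while `timesteps[step_index] >= boundary_timestep` (high-noise model), entry 1 afterwards. A minimal standalone sketch of that selection rule; all values here are hypothetical, not the class's real attributes:

    boundary_timestep = 875                     # hypothetical boundary
    sample_guide_scale = [4.0, 3.0]             # [high-noise scale, low-noise scale]
    timesteps = [999, 950, 900, 850, 800, 700]  # hypothetical schedule

    for t in timesteps:
        model = "high" if t >= boundary_timestep else "low"
        scale = sample_guide_scale[0] if t >= boundary_timestep else sample_guide_scale[1]
        print(f"t={t}: {model}-noise model, guide scale {scale}")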
@@ -434,20 +426,20 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
         high_noise_model = WanModel(
-            os.path.join(self.config.model_path, "high_noise_model"),
+            os.path.join(self.config["model_path"], "high_noise_model"),
             self.config,
             self.init_device,
         )
         low_noise_model = WanModel(
-            os.path.join(self.config.model_path, "low_noise_model"),
+            os.path.join(self.config["model_path"], "low_noise_model"),
             self.config,
             self.init_device,
         )
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            for lora_config in self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 base_name = os.path.basename(lora_path)
...
...
@@ -464,7 +456,7 @@ class Wan22MoeRunner(WanRunner):
             else:
                 raise ValueError(f"Unsupported LoRA path: {lora_path}")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
 
 
 @RUNNER_REGISTER("wan2.2")
...
...
lightx2v/models/runners/wan/wan_sf_runner.py
View file @
04812de2
...
...
@@ -9,11 +9,10 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 
-torch.manual_seed(42)
 
 @RUNNER_REGISTER("wan2.1_sf")
 class WanSFRunner(WanRunner):
...
...
@@ -59,40 +58,37 @@ class WanSFRunner(WanRunner):
         gc.collect()
         return images
 
-    @ProfilingContext4DebugL2("Run DiT")
-    def run_main(self, total_steps=None):
-        self.init_run()
-        if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
-        total_blocks = self.scheduler.num_blocks
-        gen_videos = []
-        for seg_index in range(self.video_segment_num):
-            logger.info(f"==> segment_index: {seg_index + 1} / {total_blocks}")
-            total_steps = len(self.scheduler.denoising_step_list)
-            for step_index in range(total_steps):
-                logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
-                with ProfilingContext4DebugL1("step_pre"):
-                    self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=False)
-                with ProfilingContext4DebugL1("🚀 infer_main"):
-                    self.model.infer(self.inputs)
-                with ProfilingContext4DebugL1("step_post"):
-                    self.model.scheduler.step_post()
-            latents = self.model.scheduler.stream_output
-            gen_videos.append(self.run_vae_decoder(latents))
-            # rerun with timestep zero to update KV cache using clean context
-            with ProfilingContext4DebugL1("step_pre_in_rerun"):
-                self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=True)
-            with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
+    def init_run(self):
+        super().init_run()
+
+    @ProfilingContext4DebugL1("End run segment")
+    def end_run_segment(self, segment_idx=None):
+        with ProfilingContext4DebugL1("step_pre_in_rerun"):
+            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
+        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
             self.model.infer(self.inputs)
+        self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
+
+    @peak_memory_decorator
+    def run_segment(self, total_steps=None):
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            # only for single segment, check stop signal every step
+            if self.video_segment_num == 1:
+                self.check_stop()
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
+            with ProfilingContext4DebugL1("step_pre"):
+                self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False)
+            with ProfilingContext4DebugL1("🚀 infer_main"):
+                self.model.infer(self.inputs)
-        self.gen_video = torch.cat(gen_videos, dim=0)
+            with ProfilingContext4DebugL1("step_post"):
+                self.model.scheduler.step_post()
+            if self.progress_callback:
+                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
-        self.end_run()
+        return self.model.scheduler.stream_output
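Note: the monolithic `run_main` is split so the base runner can drive one segment at a time; `end_run_segment` re-runs the final step with `is_rerun=True` so the KV cache is refreshed from the clean, fully denoised context before the next segment. A toy sketch of the control flow this implies; the orchestration names below are stand-ins (the real loop lives in DefaultRunner, which is not shown on this diff page):

    # Hypothetical outline of the per-segment loop.
    def generate(runner):
        runner.init_run()
        for seg in range(runner.video_segment_num):
            runner.segment_idx = seg
            runner.run_segment()                     # denoise this block of frames
            runner.end_run_segment(segment_idx=seg)  # rerun last step -> clean KV cache,
                                                     # then append gen_video to gen_video_final
        return runner.gen_video_final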
lightx2v/models/runners/wan/wan_vace_runner.py
View file @
04812de2
...
...
@@ -154,10 +154,10 @@ class WanVaceRunner(WanRunner):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]
 
-    def set_target_shape(self):
-        target_shape = self.latent_shape
-        target_shape[0] = int(target_shape[0] / 2)
-        self.config.target_shape = target_shape
+    def set_input_info_latent_shape(self):
+        latent_shape = self.latent_shape
+        latent_shape[0] = int(latent_shape[0] / 2)
+        return latent_shape
 
     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
...
...
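Note: VACE stacks latents and masks along the channel axis, so the first entry of `latent_shape` is halved to recover the channel count the model itself denoises; the method now returns the shape instead of writing it into `config.target_shape`. For example, with illustrative `[C, T, H, W]` values:

    latent_shape = [32, 21, 60, 104]           # illustrative stacked shape
    latent_shape[0] = int(latent_shape[0] / 2)
    print(latent_shape)                        # [16, 21, 60, 104]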
lightx2v/models/schedulers/scheduler.py
View file @
04812de2
...
...
@@ -6,8 +6,8 @@ class BaseScheduler:
         self.config = config
         self.latents = None
         self.step_index = 0
-        self.infer_steps = config.infer_steps
-        self.caching_records = [True] * config.infer_steps
+        self.infer_steps = config["infer_steps"]
+        self.caching_records = [True] * config["infer_steps"]
         self.flag_df = False
         self.transformer_infer = None
         self.infer_condition = True  # cfg status
...
...
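Note: this is the recurring change across the commit: required keys move from attribute access (`config.infer_steps`) to subscript access (`config["infer_steps"]`), while optional keys keep `get`-style lookups with defaults. With a plain mapping, a missing required key then fails loudly:

    config = {"infer_steps": 40}

    print(config["infer_steps"])             # 40 -- required key; missing -> KeyError
    print(config.get("cpu_offload", False))  # False -- optional key with explicit default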
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
04812de2
...
...
@@ -13,8 +13,8 @@ class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        if self.config.parallel:
-            self.sp_size = self.config.parallel.get("seq_p_size", 1)
+        if self.config["parallel"]:
+            self.sp_size = self.config["parallel"].get("seq_p_size", 1)
         else:
             self.sp_size = 1
...
...
@@ -33,11 +33,11 @@ class EulerScheduler(WanScheduler):
         if self.audio_adapter.cpu_offload:
             self.audio_adapter.time_embedding.to("cpu")
-        if self.config.model_cls == "wan2.2_audio":
+        if self.config["model_cls"] == "wan2.2_audio":
             _, lat_f, lat_h, lat_w = self.latents.shape
-            F = (lat_f - 1) * self.config.vae_stride[0] + 1
-            per_latent_token_len = lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * per_latent_token_len
+            F = (lat_f - 1) * self.config["vae_stride"][0] + 1
+            per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2])
+            max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len
             max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
             temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
...
...
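Note: `max_seq_len` is the token count the attention kernels see, rounded up to a multiple of the sequence-parallel world size. Plugging in illustrative numbers (`vae_stride=(4, 8, 8)` and `patch_size=(1, 2, 2)` assumed, as above):

    import math

    lat_f, lat_h, lat_w = 21, 60, 104
    vae_stride_t, patch = 4, (1, 2, 2)
    sp_size = 4                                   # hypothetical seq_p_size

    F = (lat_f - 1) * vae_stride_t + 1            # 81 pixel-space frames
    per_latent_token_len = lat_h * lat_w // (patch[1] * patch[2])       # 1560 tokens per latent frame
    max_seq_len = ((F - 1) // vae_stride_t + 1) * per_latent_token_len  # 21 * 1560 = 32760
    max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size       # already a multiple of 4 here
    print(max_seq_len)  # 32760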
@@ -55,13 +55,13 @@ class EulerScheduler(WanScheduler):
             dim=1,
         )
 
-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
-            target_shape[3],
+            latent_shape[0],
+            latent_shape[1],
+            latent_shape[2],
+            latent_shape[3],
             dtype=dtype,
             device=self.device,
             generator=self.generator,
...
...
@@ -71,8 +71,8 @@ class EulerScheduler(WanScheduler):
         if self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
 
-    def prepare(self, previmg_encoder_output=None):
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
         timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
         self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
...
...
@@ -93,11 +93,11 @@ class EulerScheduler(WanScheduler):
         if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
 
-    def reset(self, previmg_encoder_output=None):
+    def reset(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2_audio":
-            self.prev_latents = previmg_encoder_output["prev_latents"]
-            self.prev_len = previmg_encoder_output["prev_len"]
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+            self.prev_latents = image_encoder_output["prev_latents"]
+            self.prev_len = image_encoder_output["prev_len"]
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
 
     def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
         if in_tensor.ndim > tgt_n_dim:
...
...
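Note: across every scheduler in this commit, `prepare`, `reset`, and `prepare_latents` stop reading `self.config.seed` / `self.config.target_shape` and instead take `seed` and `latent_shape` as explicit arguments, so one scheduler instance can serve requests with different seeds and resolutions. A hypothetical call site under the new signatures (`scheduler` is any of the refactored classes, assumed already constructed):

    def run_request(scheduler, seed, latent_shape):
        scheduler.prepare(seed, latent_shape, image_encoder_output=None)   # first request
        ...
        scheduler.reset(seed + 1, latent_shape, image_encoder_output=None) # next request, new seed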
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @
04812de2
...
...
@@ -19,16 +19,16 @@ class WanScheduler4ChangingResolution:
             config["changing_resolution_steps"] = [config.infer_steps // 2]
         assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
 
-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents_list = []
         for i in range(len(self.config["resolution_rate"])):
             self.latents_list.append(
                 torch.randn(
-                    target_shape[0],
-                    target_shape[1],
-                    int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
-                    int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
+                    latent_shape[0],
+                    latent_shape[1],
+                    int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
                     dtype=dtype,
                     device=self.device,
                     generator=self.generator,
...
...
@@ -38,10 +38,10 @@ class WanScheduler4ChangingResolution:
         # add original resolution latents
         self.latents_list.append(
             torch.randn(
-                target_shape[0],
-                target_shape[1],
-                target_shape[2],
-                target_shape[3],
+                latent_shape[0],
+                latent_shape[1],
+                latent_shape[2],
+                latent_shape[3],
                 dtype=dtype,
                 device=self.device,
                 generator=self.generator,
...
...
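Note: the scheduler pre-draws one reduced-resolution noise tensor per `resolution_rate` entry plus one at full resolution, all from the same seeded generator; the `// 2 * 2` trick floors spatial dims to even numbers. Illustrative shapes, assuming `resolution_rate = [0.75]`:

    latent_shape = [16, 21, 60, 104]   # illustrative [C, T, H, W]
    resolution_rate = [0.75]           # illustrative config value

    shapes = [
        [latent_shape[0], latent_shape[1],
         int(latent_shape[2] * r) // 2 * 2,
         int(latent_shape[3] * r) // 2 * 2]
        for r in resolution_rate
    ] + [latent_shape]
    print(shapes)  # [[16, 21, 44, 78], [16, 21, 60, 104]]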
lightx2v/models/schedulers/wan/scheduler.py
View file @
04812de2
-import gc
 from typing import List, Optional, Union
 
 import numpy as np
...
...
@@ -12,22 +11,22 @@ class WanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.device = torch.device("cuda")
-        self.infer_steps = self.config.infer_steps
-        self.target_video_length = self.config.target_video_length
-        self.sample_shift = self.config.sample_shift
+        self.infer_steps = self.config["infer_steps"]
+        self.target_video_length = self.config["target_video_length"]
+        self.sample_shift = self.config["sample_shift"]
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
         self.solver_order = 2
         self.noise_pred = None
-        self.sample_guide_scale = self.config.sample_guide_scale
-        self.caching_records_2 = [True] * self.config.infer_steps
+        self.sample_guide_scale = self.config["sample_guide_scale"]
+        self.caching_records_2 = [True] * self.config["infer_steps"]
 
-    def prepare(self, image_encoder_output=None):
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
...
...
@@ -48,18 +47,18 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
 
-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
-            target_shape[3],
+            latent_shape[0],
+            latent_shape[1],
+            latent_shape[2],
+            latent_shape[3],
             dtype=dtype,
             device=self.device,
             generator=self.generator,
         )
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.mask = masks_like(self.latents, zero=True)
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
...
...
@@ -117,7 +116,7 @@ class WanScheduler(BaseScheduler):
         x0_pred = sample - sigma_t * model_output
         return x0_pred
 
-    def reset(self, step_index=None):
+    def reset(self, seed, latent_shape, step_index=None):
         if step_index is not None:
             self.step_index = step_index
         self.model_outputs = [None] * self.solver_order
...
...
@@ -126,9 +125,7 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.this_order = None
         self.lower_order_nums = 0
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
-        gc.collect()
-        torch.cuda.empty_cache()
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
 
     def multistep_uni_p_bh_update(
         self,
...
...
@@ -325,7 +322,7 @@ class WanScheduler(BaseScheduler):
     def step_pre(self, step_index):
         super().step_pre(step_index)
         self.timestep_input = torch.stack([self.timesteps[self.step_index]])
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
 
     def step_post(self):
...
...
@@ -367,5 +364,5 @@ class WanScheduler(BaseScheduler):
                 self.lower_order_nums += 1
         self.latents = prev_sample
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
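Note: the `(1.0 - mask) * vae_encoder_out + mask * latents` blend keeps conditioning positions pinned to the encoded frames (mask 0) while masked positions (mask 1) carry noise and denoised samples; the refactor only widens the guard to `task in ["i2v", "s2v"]`. A scalar-shaped sketch of the blend:

    import torch

    mask = torch.tensor([0.0, 0.0, 1.0, 1.0])        # 0 = conditioned, 1 = generated
    vae_encoder_out = torch.tensor([5.0, 5.0, 5.0, 5.0])
    latents = torch.randn(4)

    latents = (1.0 - mask) * vae_encoder_out + mask * latents
    print(latents[:2])  # tensor([5., 5.]) -- conditioned slots untouched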
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
View file @
04812de2
...
...
@@ -9,24 +9,25 @@ class WanSFScheduler(WanScheduler):
         super().__init__(config)
         self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        self.num_frame_per_block = self.config.sf_config.num_frame_per_block
-        self.num_output_frames = self.config.sf_config.num_output_frames
+        self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
+        self.num_output_frames = self.config["sf_config"]["num_output_frames"]
         self.num_blocks = self.num_output_frames // self.num_frame_per_block
-        self.denoising_step_list = self.config.sf_config.denoising_step_list
+        self.denoising_step_list = self.config["sf_config"]["denoising_step_list"]
+        self.infer_steps = len(self.denoising_step_list)
         self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
         self.num_input_frames = 0
         self.denoising_strength = 1.0
         self.sigma_max = 1.0
         self.sigma_min = 0
-        self.sf_shift = self.config.sf_config.shift
+        self.sf_shift = self.config["sf_config"]["shift"]
         self.inverse_timesteps = False
         self.extra_one_step = True
         self.reverse_sigmas = False
-        self.num_inference_steps = self.config.sf_config.num_inference_steps
+        self.num_inference_steps = self.config["sf_config"]["num_inference_steps"]
         self.context_noise = 0
 
-    def prepare(self, image_encoder_output=None):
-        self.latents = torch.randn(self.config.target_shape, device=self.device, dtype=self.dtype)
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
         timesteps = []
         for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
...
...
@@ -39,7 +40,7 @@ class WanSFScheduler(WanScheduler):
             timesteps.append(frame_steps)
         self.timesteps = timesteps
-        self.noise_pred = torch.zeros(self.config.target_shape, device=self.device, dtype=self.dtype)
+        self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)
         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
         if self.extra_one_step:
...
...
@@ -91,7 +92,7 @@ class WanSFScheduler(WanScheduler):
         x0_pred = x0_pred.to(original_dtype)
         # add noise
-        if self.step_index < len(self.denoising_step_list) - 1:
+        if self.step_index < self.infer_steps - 1:
             timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
             timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
             sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
...
...
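Note: `infer_steps` is now derived once from `sf_config["denoising_step_list"]`, so the re-noising guard compares against `self.infer_steps - 1` instead of recomputing the list length at every step. With a typical (here hypothetical) four-step list the loop behaves like:

    denoising_step_list = [999, 750, 500, 250]   # illustrative sf_config entry
    infer_steps = len(denoising_step_list)       # 4

    for step_index, t in enumerate(denoising_step_list):
        renoise = step_index < infer_steps - 1   # skip re-noising after the last step
        print(f"step {step_index}: t={t}, renoise for next step: {renoise}")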