Unverified commit 04812de2, authored Sep 29, 2025 by Yang Yong (雍洋), committed by GitHub on Sep 29, 2025

Refactor Config System (#338)

Parent: 6a658f42

Showing 20 changed files with 329 additions and 417 deletions (+329, −417)
lightx2v/models/networks/wan/sf_model.py (+2, −2)
lightx2v/models/networks/wan/weights/pre_weights.py (+4, −4)
lightx2v/models/networks/wan/weights/transformer_weights.py (+3, −3)
lightx2v/models/runners/base_runner.py (+2, −23)
lightx2v/models/runners/cogvideox/cogvidex_runner.py (+0, −9)
lightx2v/models/runners/default_runner.py (+69, −59)
lightx2v/models/runners/hunyuan/hunyuan_runner.py (+0, −4)
lightx2v/models/runners/qwen_image/qwen_image_runner.py (+1, −1)
lightx2v/models/runners/wan/wan_animate_runner.py (+1, −1)
lightx2v/models/runners/wan/wan_audio_runner.py (+74, −124)
lightx2v/models/runners/wan/wan_causvid_runner.py (+1, −1)
lightx2v/models/runners/wan/wan_distill_runner.py (+19, −19)
lightx2v/models/runners/wan/wan_runner.py (+58, −66)
lightx2v/models/runners/wan/wan_sf_runner.py (+32, −36)
lightx2v/models/runners/wan/wan_vace_runner.py (+4, −4)
lightx2v/models/schedulers/scheduler.py (+2, −2)
lightx2v/models/schedulers/wan/audio/scheduler.py (+18, −18)
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py (+10, −10)
lightx2v/models/schedulers/wan/scheduler.py (+19, −22)
lightx2v/models/schedulers/wan/self_forcing/scheduler.py (+10, −9)
lightx2v/models/networks/wan/sf_model.py

@@ -14,8 +14,8 @@ class WanSFModel(WanModel):
         self.to_cuda()

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        sf_confg = self.config.sf_config
-        file_path = os.path.join(self.config.sf_model_path, f"checkpoints/self_forcing_{sf_confg.sf_type}.pt")
+        sf_confg = self.config["sf_config"]
+        file_path = os.path.join(self.config["sf_model_path"], f"checkpoints/self_forcing_{sf_confg['sf_type']}.pt")
         _weight_dict = torch.load(file_path)["generator_ema"]
         weight_dict = {}
         for k, v in _weight_dict.items():
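Note: this hunk is representative of the whole commit: attribute-style reads on the config object (self.config.sf_model_path) become dict-style indexing (self.config["sf_model_path"]). A minimal sketch of a config type that accepts both styles during such a migration; the class below is hypothetical, not LightX2V's actual Config:

# Hypothetical dict-backed config that also answers attribute reads,
# useful while call sites migrate from cfg.key to cfg["key"].
class Config(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

cfg = Config(sf_model_path="/models/sf", sf_config={"sf_type": "dmd"})
assert cfg["sf_model_path"] == cfg.sf_model_path  # both access styles resolve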
lightx2v/models/networks/wan/weights/pre_weights.py

@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
                 MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
             )
-        if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
+        if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),

@@ -58,7 +58,7 @@ class WanPreWeights(WeightModule):
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
             )
-        if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
+        if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
             self.add_module(
                 "cfg_cond_proj_1",
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),

@@ -68,12 +68,12 @@ class WanPreWeights(WeightModule):
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
             )
-        if config.task == "flf2v" and config.get("use_image_encoder", True):
+        if config["task"] == "flf2v" and config.get("use_image_encoder", True):
             self.add_module(
                 "emb_pos",
                 TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
             )
-        if config.task == "animate":
+        if config["task"] == "animate":
             self.add_module(
                 "pose_patch_embedding",
                 CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
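Note: besides the dict-style config access, the first hunk widens the task gate to include "s2v". The *_REGISTER["Default"](...) calls follow a registry pattern: a string key maps to a weight-wrapper class that is instantiated with checkpoint tensor names. A hedged sketch of that pattern; the Register class and MMWeight body are illustrative, not the project's implementation:

# Illustrative registry: maps a key to a class, then instantiates it with
# tensor names, as MM_WEIGHT_REGISTER["Default"](...) does in the diff above.
class Register(dict):
    def register(self, name):
        def decorator(cls):
            self[name] = cls
            return cls
        return decorator

MM_WEIGHT_REGISTER = Register()

@MM_WEIGHT_REGISTER.register("Default")
class MMWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name

mm = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")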
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -60,7 +60,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
-            lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
+            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
             self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         else:
             self.lazy_load_file = None

@@ -197,7 +197,7 @@ class WanSelfAttention(WeightModule):
         self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.config["seq_parallel"]:
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
+            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(

@@ -296,7 +296,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             self.add_module(
                 "cross_attn_k_img",
                 MM_WEIGHT_REGISTER[self.mm_type](
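Note: the lazy_load branch keeps per-block weights on disk and opens them with safetensors' safe_open, so a block's tensors are only materialized when requested. A self-contained sketch of that access pattern; the directory and tensor names are placeholders, while the block file naming follows the diff:

import os
from safetensors import safe_open

def open_block(ckpt_dir, block_index):
    # One .safetensors file per transformer block, as in the diff above.
    path = os.path.join(ckpt_dir, f"block_{block_index}.safetensors")
    # safe_open maps the file without loading it; tensors load on demand.
    return safe_open(path, framework="pt", device="cpu")

# handle = open_block("/ckpts/dit_quantized", 0)
# w = handle.get_tensor("self_attn.q.weight")  # materialized only here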
lightx2v/models/runners/base_runner.py

@@ -3,8 +3,6 @@ from abc import ABC
 import torch
 import torch.distributed as dist

-from lightx2v.utils.utils import save_videos_grid
-

 class BaseRunner(ABC):
     """Abstract base class for all Runners

@@ -15,6 +13,7 @@ class BaseRunner(ABC):
     def __init__(self, config):
         self.config = config
         self.vae_encoder_need_img_original = False
+        self.input_info = None

     def load_transformer(self):
         """Load transformer model

@@ -100,26 +99,6 @@ class BaseRunner(ABC):
         """Initialize scheduler"""
         pass

-    def set_target_shape(self):
-        """Set target shape
-
-        Subclasses can override this method to provide specific implementation
-
-        Returns:
-            Dictionary containing target shape information
-        """
-        return {}
-
-    def save_video_func(self, images):
-        """Save video implementation
-
-        Subclasses can override this method to customize save logic
-
-        Args:
-            images: Image sequence to save
-        """
-        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
-
     def load_vae_decoder(self):
         """Load VAE decoder

@@ -146,7 +125,7 @@ class BaseRunner(ABC):
         pass

     def end_run_segment(self, segment_idx=None):
-        pass
+        self.gen_video_final = self.gen_video

     def end_run(self):
         pass
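Note: the new self.input_info carries per-request state that previously lived on the config. Later hunks probe it with "image_path" in self.input_info.__dataclass_fields__, so it is presumably a dataclass. A hedged sketch of such a container; the field names come from this commit's hunks, the defaults are guesses:

from dataclasses import dataclass, field

@dataclass
class InputInfo:
    seed: int = 42
    prompt: str = ""
    prompt_enhanced: str = ""
    negative_prompt: str = ""
    image_path: str = ""
    save_result_path: str = ""
    return_result_tensor: bool = False
    latent_shape: list = field(default_factory=list)

info = InputInfo(seed=7, prompt="a cat surfing")
assert "image_path" in info.__dataclass_fields__  # the probe used in set_inputs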
lightx2v/models/runners/cogvideox/cogvidex_runner.py

-import imageio
-import numpy as np
 from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
 from lightx2v.models.networks.cogvideox.model import CogvideoxModel
 from lightx2v.models.runners.default_runner import DefaultRunner

@@ -72,9 +69,3 @@ class CogvideoxRunner(DefaultRunner):
         )
         ret["target_shape"] = self.config.target_shape
         return ret
-
-    def save_video_func(self, images):
-        with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
-            for pil_image in images:
-                frame_np = np.array(pil_image, dtype=np.uint8)
-                writer.append_data(frame_np)
lightx2v/models/runners/default_runner.py

@@ -22,13 +22,13 @@ class DefaultRunner(BaseRunner):
         super().__init__(config)
         self.has_prompt_enhancer = False
         self.progress_callback = None
-        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
             self.has_prompt_enhancer = True
             if not self.check_sub_servers("prompt_enhancer"):
                 self.has_prompt_enhancer = False
                 logger.warning("No prompt enhancer server available, disable prompt enhancer.")
         if not self.has_prompt_enhancer:
-            self.config.use_prompt_enhancer = False
+            self.config["use_prompt_enhancer"] = False
         self.set_init_device()
         self.init_scheduler()

@@ -49,12 +49,15 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_vace
         elif self.config["task"] == "animate":
             self.run_input_encoder = self._run_input_encoder_local_animate
+        elif self.config["task"] == "s2v":
+            self.run_input_encoder = self._run_input_encoder_local_s2v
+        self.config.lock()  # lock config to avoid modification
         if self.config.get("compile", False):
             logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
             self.model.compile(self.config.get("compile_shapes", []))

     def set_init_device(self):
-        if self.config.cpu_offload:
+        if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
             self.init_device = torch.device("cuda")

@@ -96,21 +99,23 @@ class DefaultRunner(BaseRunner):
         return len(available_servers) > 0

     def set_inputs(self, inputs):
-        self.config["prompt"] = inputs.get("prompt", "")
-        self.config["use_prompt_enhancer"] = False
-        if self.has_prompt_enhancer:
-            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
-        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
-        self.config["image_path"] = inputs.get("image_path", "")
-        self.config["save_video_path"] = inputs.get("save_video_path", "")
-        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
-        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
-        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
-        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
-        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio
-        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
-        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
+        self.input_info.seed = inputs.get("seed", 42)
+        self.input_info.prompt = inputs.get("prompt", "")
+        if self.config["use_prompt_enhancer"]:
+            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
+        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
+        if "image_path" in self.input_info.__dataclass_fields__:
+            self.input_info.image_path = inputs.get("image_path", "")
+        if "audio_path" in self.input_info.__dataclass_fields__:
+            self.input_info.audio_path = inputs.get("audio_path", "")
+        if "video_path" in self.input_info.__dataclass_fields__:
+            self.input_info.video_path = inputs.get("video_path", "")
+        self.input_info.save_result_path = inputs.get("save_result_path", "")
+
+    def set_config(self, config_modify):
+        logger.info(f"modify config: {config_modify}")
+        with self.config.temporarily_unlocked():
+            self.config.update(config_modify)

     def set_progress_callback(self, callback):
         self.progress_callback = callback

@@ -146,6 +151,7 @@ class DefaultRunner(BaseRunner):
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs
+        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                 self.model.transformer_infer.weights_stream_mgr.clear()

@@ -162,23 +168,24 @@ class DefaultRunner(BaseRunner):
         else:
             img_ori = Image.open(img_path).convert("RGB")
         img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        self.input_info.original_size = img_ori.size
         return img, img_ori

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img, img_ori = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.input_info.image_path)
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
-        text_encoder_output = self.run_text_encoder(prompt, img)
+        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return {

@@ -188,22 +195,21 @@ class DefaultRunner(BaseRunner):
     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        first_frame, _ = self.read_image_input(self.config["image_path"])
-        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
+        first_frame, _ = self.read_image_input(self.input_info.image_path)
+        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
         clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
-        text_encoder_output = self.run_text_encoder(prompt, first_frame)
+        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_vace(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        src_video = self.config.get("src_video", None)
-        src_mask = self.config.get("src_mask", None)
-        src_ref_images = self.config.get("src_ref_images", None)
+        src_video = self.input_info.src_video
+        src_mask = self.input_info.src_mask
+        src_ref_images = self.input_info.src_ref_images
         src_video, src_mask, src_ref_images = self.prepare_source(
             [src_video],
             [src_mask],

@@ -212,34 +218,38 @@ class DefaultRunner(BaseRunner):
         )
         self.src_ref_images = src_ref_images
-        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
-        text_encoder_output = self.run_text_encoder(prompt)
+        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Text Encoder")
     def _run_input_encoder_local_animate(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

+    def _run_input_encoder_local_s2v(self):
+        pass
+
     def init_run(self):
-        self.set_target_shape()
+        self.gen_video_final = None
         self.get_video_segment_num()
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
+        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
+        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

     @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
         if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
+            self.model.select_graph_for_compile(self.input_info)
         for segment_idx in range(self.video_segment_num):
             logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
             with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):

@@ -252,7 +262,9 @@ class DefaultRunner(BaseRunner):
                 self.gen_video = self.run_vae_decoder(latents)
                 # 4. default do nothing
                 self.end_run_segment(segment_idx)
+        gen_video_final = self.process_images_after_vae_decoder()
         self.end_run()
+        return {"video": gen_video_final}

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):

@@ -281,20 +293,22 @@ class DefaultRunner(BaseRunner):
         logger.info(f"Enhanced prompt: {enhanced_prompt}")
         return enhanced_prompt

-    def process_images_after_vae_decoder(self, save_video=True):
-        self.gen_video = vae_to_comfyui_image(self.gen_video)
+    def process_images_after_vae_decoder(self):
+        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)
         if "video_frame_interpolation" in self.config:
             assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
             target_fps = self.config["video_frame_interpolation"]["target_fps"]
             logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
-            self.gen_video = self.vfi_model.interpolate_frames(
-                self.gen_video,
+            self.gen_video_final = self.vfi_model.interpolate_frames(
+                self.gen_video_final,
                 source_fps=self.config.get("fps", 16),
                 target_fps=target_fps,
             )
-        if save_video:
+        if self.input_info.return_result_tensor:
+            return {"video": self.gen_video_final}
+        elif self.input_info.save_result_path is not None:
             if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                 fps = self.config["video_frame_interpolation"]["target_fps"]
             else:

@@ -303,22 +317,18 @@ class DefaultRunner(BaseRunner):
             if not dist.is_initialized() or dist.get_rank() == 0:
                 logger.info(f"🎬 Start to save video 🎬")
-                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
-                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
-        if self.config.get("return_video", False):
-            return {"video": self.gen_video}
-        return {"video": None}
+                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
+                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
+        return {"video": None}

-    def run_pipeline(self, save_video=True):
+    def run_pipeline(self, input_info):
+        self.input_info = input_info
         if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.input_info.prompt_enhanced = self.post_prompt_enhancer()
         self.inputs = self.run_input_encoder()
-        self.run_main()
-        gen_video = self.process_images_after_vae_decoder(save_video=save_video)
-        return gen_video
+        gen_video_final = self.run_main()
+        torch.cuda.empty_cache()
+        gc.collect()
+        return gen_video_final
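Note: init_modules now calls self.config.lock() and the new set_config mutates settings only inside self.config.temporarily_unlocked(). A minimal sketch of that lock protocol, assuming a dict-backed config; LightX2V's real class may differ:

from contextlib import contextmanager

class LockableConfig(dict):
    _locked = False

    def lock(self):
        self._locked = True

    @contextmanager
    def temporarily_unlocked(self):
        previous, self._locked = self._locked, False
        try:
            yield self
        finally:
            self._locked = previous

    def __setitem__(self, key, value):
        if self._locked:
            raise RuntimeError(f"config is locked; cannot set {key!r}")
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Route through __setitem__ so the lock also guards update().
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

cfg = LockableConfig(task="t2v")
cfg.lock()
with cfg.temporarily_unlocked():
    cfg.update({"infer_steps": 4})  # allowed only inside the context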
lightx2v/models/runners/hunyuan/hunyuan_runner.py

@@ -14,7 +14,6 @@ from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_videos_grid


 @RUNNER_REGISTER("hunyuan")

@@ -152,6 +151,3 @@ class HunyuanRunner(DefaultRunner):
             int(self.config.target_width) // vae_scale_factor,
         )
         return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
-
-    def save_video_func(self, images):
-        save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -204,7 +204,7 @@ class QwenImageRunner(DefaultRunner):
         images = self.run_vae_decoder(latents, generator)
         image = images[0]
-        image.save(f"{self.config.save_video_path}")
+        image.save(f"{self.config.save_result_path}")
         del latents, generator
         torch.cuda.empty_cache()
lightx2v/models/runners/wan/wan_animate_runner.py

@@ -334,7 +334,7 @@ class WanAnimateRunner(WanRunner):
             )
         if start != 0:
-            self.model.scheduler.reset()
+            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)

     def end_run_segment(self, segment_idx):
         if segment_idx != 0:
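Note: scheduler.reset now takes the seed and latent shape explicitly instead of reading them off the config. Only the signature is visible in this diff; the body below is an illustrative guess at what a reset might do:

import torch

class SchedulerSketch:
    def reset(self, seed, latent_shape, image_encoder_output=None):
        # Hypothetical body: re-seed and draw fresh noise for the next segment.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        self.latents = torch.randn(latent_shape, generator=generator)

scheduler = SchedulerSketch()
scheduler.reset(seed=42, latent_shape=[16, 21, 60, 104])
print(scheduler.latents.shape)  # torch.Size([16, 21, 60, 104])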
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
04812de2
import
gc
import
gc
import
json
import
os
import
os
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
...
@@ -337,16 +337,14 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -337,16 +337,14 @@ class WanAudioRunner(WanRunner): # type:ignore
"""Initialize consistency model scheduler"""
"""Initialize consistency model scheduler"""
self
.
scheduler
=
EulerScheduler
(
self
.
config
)
self
.
scheduler
=
EulerScheduler
(
self
.
config
)
def
read_audio_input
(
self
):
def
read_audio_input
(
self
,
audio_path
):
"""Read audio input - handles both single and multi-person scenarios"""
"""Read audio input - handles both single and multi-person scenarios"""
audio_sr
=
self
.
config
.
get
(
"audio_sr"
,
16000
)
audio_sr
=
self
.
config
.
get
(
"audio_sr"
,
16000
)
target_fps
=
self
.
config
.
get
(
"target_fps"
,
16
)
target_fps
=
self
.
config
.
get
(
"target_fps"
,
16
)
self
.
_audio_processor
=
AudioProcessor
(
audio_sr
,
target_fps
)
self
.
_audio_processor
=
AudioProcessor
(
audio_sr
,
target_fps
)
# Get audio files from person objects or legacy format
# Get audio files from person objects or legacy format
audio_files
=
self
.
_get_audio_files_from_config
()
audio_files
,
mask_files
=
self
.
get_audio_files_from_audio_path
(
audio_path
)
if
not
audio_files
:
return
[],
0
# Load audio based on single or multi-person mode
# Load audio based on single or multi-person mode
if
len
(
audio_files
)
==
1
:
if
len
(
audio_files
)
==
1
:
...
@@ -355,8 +353,6 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -355,8 +353,6 @@ class WanAudioRunner(WanRunner): # type:ignore
else
:
else
:
audio_array
=
self
.
_audio_processor
.
load_multi_person_audio
(
audio_files
)
audio_array
=
self
.
_audio_processor
.
load_multi_person_audio
(
audio_files
)
self
.
config
.
audio_num
=
audio_array
.
size
(
0
)
video_duration
=
self
.
config
.
get
(
"video_duration"
,
5
)
video_duration
=
self
.
config
.
get
(
"video_duration"
,
5
)
audio_len
=
int
(
audio_array
.
shape
[
1
]
/
audio_sr
*
target_fps
)
audio_len
=
int
(
audio_array
.
shape
[
1
]
/
audio_sr
*
target_fps
)
expected_frames
=
min
(
max
(
1
,
int
(
video_duration
*
target_fps
)),
audio_len
)
expected_frames
=
min
(
max
(
1
,
int
(
video_duration
*
target_fps
)),
audio_len
)
...
@@ -364,60 +360,35 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -364,60 +360,35 @@ class WanAudioRunner(WanRunner): # type:ignore
# Segment audio
# Segment audio
audio_segments
=
self
.
_audio_processor
.
segment_audio
(
audio_array
,
expected_frames
,
self
.
config
.
get
(
"target_video_length"
,
81
),
self
.
prev_frame_length
)
audio_segments
=
self
.
_audio_processor
.
segment_audio
(
audio_array
,
expected_frames
,
self
.
config
.
get
(
"target_video_length"
,
81
),
self
.
prev_frame_length
)
return
audio_array
.
size
(
0
),
audio_segments
,
expected_frames
# Mask latent for multi-person s2v
if
mask_files
is
not
None
:
def
_get_audio_files_from_config
(
self
):
mask_latents
=
[
self
.
process_single_mask
(
mask_file
)
for
mask_file
in
mask_files
]
talk_objects
=
self
.
config
.
get
(
"talk_objects"
)
mask_latents
=
torch
.
cat
(
mask_latents
,
dim
=
0
)
if
talk_objects
:
else
:
audio_files
=
[]
mask_latents
=
None
for
idx
,
person
in
enumerate
(
talk_objects
):
audio_path
=
person
.
get
(
"audio"
)
if
audio_path
and
Path
(
audio_path
).
is_file
():
audio_files
.
append
(
str
(
audio_path
))
else
:
logger
.
warning
(
f
"Person
{
idx
}
audio file
{
audio_path
}
does not exist or not specified"
)
if
audio_files
:
logger
.
info
(
f
"Loaded
{
len
(
audio_files
)
}
audio files from talk_objects"
)
return
audio_files
audio_path
=
self
.
config
.
get
(
"audio_path"
)
if
audio_path
:
return
[
audio_path
]
logger
.
error
(
"config audio_path or talk_objects is not specified"
)
return
[]
def
read_person_mask
(
self
):
return
audio_segments
,
expected_frames
,
mask_latents
,
len
(
audio_files
)
mask_files
=
self
.
_get_mask_files_from_config
()
if
not
mask_files
:
return
None
mask_latents
=
[]
def
get_audio_files_from_audio_path
(
self
,
audio_path
):
for
mask_file
in
mask_files
:
if
os
.
path
.
isdir
(
audio_path
):
mask_latent
=
self
.
_process_single_mask
(
mask_file
)
audio_files
=
[]
mask_latents
.
append
(
mask_latent
)
mask_files
=
[]
logger
.
info
(
f
"audio_path is a directory, loading config.json from
{
audio_path
}
"
)
audio_config_path
=
os
.
path
.
join
(
audio_path
,
"config.json"
)
assert
os
.
path
.
exists
(
audio_config_path
),
"config.json not found in audio_path"
with
open
(
audio_config_path
,
"r"
)
as
f
:
audio_config
=
json
.
load
(
f
)
for
talk_object
in
audio_config
[
"talk_objects"
]:
audio_files
.
append
(
os
.
path
.
join
(
audio_path
,
talk_object
[
"audio"
]))
mask_files
.
append
(
os
.
path
.
join
(
audio_path
,
talk_object
[
"mask"
]))
else
:
logger
.
info
(
f
"audio_path is a file without mask:
{
audio_path
}
"
)
audio_files
=
[
audio_path
]
mask_files
=
None
mask_latents
=
torch
.
cat
(
mask_latents
,
dim
=
0
)
return
audio_files
,
mask_files
return
mask_latents
def
_get_mask_files_from_config
(
self
):
def
process_single_mask
(
self
,
mask_file
):
talk_objects
=
self
.
config
.
get
(
"talk_objects"
)
if
talk_objects
:
mask_files
=
[]
for
idx
,
person
in
enumerate
(
talk_objects
):
mask_path
=
person
.
get
(
"mask"
)
if
mask_path
and
Path
(
mask_path
).
is_file
():
mask_files
.
append
(
str
(
mask_path
))
elif
mask_path
:
logger
.
warning
(
f
"Person
{
idx
}
mask file
{
mask_path
}
does not exist"
)
if
mask_files
:
logger
.
info
(
f
"Loaded
{
len
(
mask_files
)
}
mask files from talk_objects"
)
return
mask_files
logger
.
info
(
"config talk_objects is not specified"
)
return
None
def
_process_single_mask
(
self
,
mask_file
):
mask_img
=
Image
.
open
(
mask_file
).
convert
(
"RGB"
)
mask_img
=
Image
.
open
(
mask_file
).
convert
(
"RGB"
)
mask_img
=
TF
.
to_tensor
(
mask_img
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
cuda
()
mask_img
=
TF
.
to_tensor
(
mask_img
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
cuda
()
...
@@ -456,21 +427,21 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -456,21 +427,21 @@ class WanAudioRunner(WanRunner): # type:ignore
fixed_shape
=
self
.
config
.
get
(
"fixed_shape"
,
None
),
fixed_shape
=
self
.
config
.
get
(
"fixed_shape"
,
None
),
)
)
logger
.
info
(
f
"[wan_audio] resize_image target_h:
{
h
}
, target_w:
{
w
}
"
)
logger
.
info
(
f
"[wan_audio] resize_image target_h:
{
h
}
, target_w:
{
w
}
"
)
patched_h
=
h
//
self
.
config
.
vae_stride
[
1
]
//
self
.
config
.
patch_size
[
1
]
patched_h
=
h
//
self
.
config
[
"
vae_stride
"
]
[
1
]
//
self
.
config
[
"
patch_size
"
]
[
1
]
patched_w
=
w
//
self
.
config
.
vae_stride
[
2
]
//
self
.
config
.
patch_size
[
2
]
patched_w
=
w
//
self
.
config
[
"
vae_stride
"
]
[
2
]
//
self
.
config
[
"
patch_size
"
]
[
2
]
patched_h
,
patched_w
=
get_optimal_patched_size_with_sp
(
patched_h
,
patched_w
,
1
)
patched_h
,
patched_w
=
get_optimal_patched_size_with_sp
(
patched_h
,
patched_w
,
1
)
self
.
config
.
la
t_h
=
patched_h
*
self
.
config
.
patch_size
[
1
]
laten
t_h
=
patched_h
*
self
.
config
[
"
patch_size
"
]
[
1
]
self
.
config
.
la
t_w
=
patched_w
*
self
.
config
.
patch_size
[
2
]
laten
t_w
=
patched_w
*
self
.
config
[
"
patch_size
"
]
[
2
]
self
.
config
.
tgt_h
=
self
.
config
.
lat_h
*
self
.
config
.
vae_stride
[
1
]
latent_shape
=
self
.
get_latent_shape_with_lat_hw
(
latent_h
,
latent_w
)
self
.
config
.
tgt_w
=
self
.
config
.
la
t_w
*
self
.
config
.
vae_stride
[
2
]
target_shape
=
[
latent_h
*
self
.
config
[
"vae_stride"
][
1
],
laten
t_w
*
self
.
config
[
"
vae_stride
"
]
[
2
]
]
logger
.
info
(
f
"[wan_audio] t
g
t_h:
{
self
.
config
.
tgt_h
}
, tgt_w:
{
self
.
config
.
tgt_w
}
, lat_h:
{
self
.
config
.
la
t_h
}
, lat_w:
{
self
.
config
.
la
t_w
}
"
)
logger
.
info
(
f
"[wan_audio] t
arge
t_h:
{
target_shape
[
0
]
}
, target_w:
{
target_shape
[
1
]
}
, lat
ent
_h:
{
laten
t_h
}
, lat
ent
_w:
{
laten
t_w
}
"
)
ref_img
=
torch
.
nn
.
functional
.
interpolate
(
ref_img
,
size
=
(
self
.
config
.
tgt_h
,
self
.
config
.
tgt_w
),
mode
=
"bicubic"
)
ref_img
=
torch
.
nn
.
functional
.
interpolate
(
ref_img
,
size
=
(
target_shape
[
0
],
target_shape
[
1
]
),
mode
=
"bicubic"
)
return
ref_img
return
ref_img
,
latent_shape
,
target_shape
def
run_image_encoder
(
self
,
first_frame
,
last_frame
=
None
):
def
run_image_encoder
(
self
,
first_frame
,
last_frame
=
None
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
...
@@ -496,20 +467,17 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -496,20 +467,17 @@ class WanAudioRunner(WanRunner): # type:ignore
return
vae_encoder_out
return
vae_encoder_out
@
ProfilingContext4DebugL2
(
"Run Encoders"
)
@
ProfilingContext4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_r2v_audio
(
self
):
def
_run_input_encoder_local_s2v
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
img
,
latent_shape
,
target_shape
=
self
.
read_image_input
(
self
.
input_info
.
image_path
)
img
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
self
.
input_info
.
latent_shape
=
latent_shape
# Important: set latent_shape in input_info
self
.
input_info
.
target_shape
=
target_shape
# Important: set target_shape in input_info
clip_encoder_out
=
self
.
run_image_encoder
(
img
)
if
self
.
config
.
get
(
"use_image_encoder"
,
True
)
else
None
clip_encoder_out
=
self
.
run_image_encoder
(
img
)
if
self
.
config
.
get
(
"use_image_encoder"
,
True
)
else
None
vae_encode_out
=
self
.
run_vae_encoder
(
img
)
vae_encode_out
=
self
.
run_vae_encoder
(
img
)
audio_num
,
audio_segments
,
expected_frames
=
self
.
read_audio_input
()
audio_segments
,
expected_frames
,
person_mask_latens
,
audio_num
=
self
.
read_audio_input
(
self
.
input_info
.
audio_path
)
person_mask_latens
=
self
.
read_person_mask
()
self
.
input_info
.
audio_num
=
audio_num
self
.
config
.
person_num
=
0
self
.
input_info
.
with_mask
=
person_mask_latens
is
not
None
if
person_mask_latens
is
not
None
:
text_encoder_output
=
self
.
run_text_encoder
(
self
.
input_info
)
assert
audio_num
==
person_mask_latens
.
size
(
0
),
"audio_num and person_mask_latens.size(0) must be the same"
self
.
config
.
person_num
=
person_mask_latens
.
size
(
0
)
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
None
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
gc
.
collect
()
return
{
return
{
...
@@ -528,13 +496,13 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -528,13 +496,13 @@ class WanAudioRunner(WanRunner): # type:ignore
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
dtype
=
GET_DTYPE
()
dtype
=
GET_DTYPE
()
tgt_h
,
tgt_w
=
self
.
config
.
tgt_h
,
self
.
config
.
tgt_w
tgt_h
,
tgt_w
=
self
.
input_info
.
target_shape
[
0
],
self
.
input_info
.
target_shape
[
1
]
prev_frames
=
torch
.
zeros
((
1
,
3
,
self
.
config
.
target_video_length
,
tgt_h
,
tgt_w
),
device
=
device
)
prev_frames
=
torch
.
zeros
((
1
,
3
,
self
.
config
[
"
target_video_length
"
]
,
tgt_h
,
tgt_w
),
device
=
device
)
if
prev_video
is
not
None
:
if
prev_video
is
not
None
:
# Extract and process last frames
# Extract and process last frames
last_frames
=
prev_video
[:,
:,
-
prev_frame_length
:].
clone
().
to
(
device
)
last_frames
=
prev_video
[:,
:,
-
prev_frame_length
:].
clone
().
to
(
device
)
if
self
.
config
.
model_cls
!=
"wan2.2_audio"
:
if
self
.
config
[
"
model_cls
"
]
!=
"wan2.2_audio"
:
last_frames
=
self
.
frame_preprocessor
.
process_prev_frames
(
last_frames
)
last_frames
=
self
.
frame_preprocessor
.
process_prev_frames
(
last_frames
)
prev_frames
[:,
:,
:
prev_frame_length
]
=
last_frames
prev_frames
[:,
:,
:
prev_frame_length
]
=
last_frames
prev_len
=
(
prev_frame_length
-
1
)
//
4
+
1
prev_len
=
(
prev_frame_length
-
1
)
//
4
+
1
...
@@ -546,7 +514,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -546,7 +514,7 @@ class WanAudioRunner(WanRunner): # type:ignore
_
,
nframe
,
height
,
width
=
self
.
model
.
scheduler
.
latents
.
shape
_
,
nframe
,
height
,
width
=
self
.
model
.
scheduler
.
latents
.
shape
with
ProfilingContext4DebugL1
(
"vae_encoder in init run segment"
):
with
ProfilingContext4DebugL1
(
"vae_encoder in init run segment"
):
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
self
.
config
[
"
model_cls
"
]
==
"wan2.2_audio"
:
if
prev_video
is
not
None
:
if
prev_video
is
not
None
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
dtype
))
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
dtype
))
else
:
else
:
...
@@ -563,7 +531,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -563,7 +531,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if
prev_latents
is
not
None
:
if
prev_latents
is
not
None
:
if
prev_latents
.
shape
[
-
2
:]
!=
(
height
,
width
):
if
prev_latents
.
shape
[
-
2
:]
!=
(
height
,
width
):
logger
.
warning
(
f
"Size mismatch: prev_latents
{
prev_latents
.
shape
}
vs scheduler latents (H=
{
height
}
, W=
{
width
}
). Config tgt_h=
{
self
.
config
.
tgt_h
}
, tgt_w=
{
self
.
config
.
tgt_w
}
"
)
logger
.
warning
(
f
"Size mismatch: prev_latents
{
prev_latents
.
shape
}
vs scheduler latents (H=
{
height
}
, W=
{
width
}
). Config tgt_h=
{
tgt_h
}
, tgt_w=
{
tgt_w
}
"
)
prev_latents
=
torch
.
nn
.
functional
.
interpolate
(
prev_latents
,
size
=
(
height
,
width
),
mode
=
"bilinear"
,
align_corners
=
False
)
prev_latents
=
torch
.
nn
.
functional
.
interpolate
(
prev_latents
,
size
=
(
height
,
width
),
mode
=
"bilinear"
,
align_corners
=
False
)
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
...
@@ -592,8 +560,8 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -592,8 +560,8 @@ class WanAudioRunner(WanRunner): # type:ignore
super
().
init_run
()
super
().
init_run
()
self
.
scheduler
.
set_audio_adapter
(
self
.
audio_adapter
)
self
.
scheduler
.
set_audio_adapter
(
self
.
audio_adapter
)
self
.
prev_video
=
None
self
.
prev_video
=
None
if
self
.
config
.
get
(
"return_video"
,
False
)
:
if
self
.
input_info
.
return_result_tensor
:
self
.
gen_video_final
=
torch
.
zeros
((
self
.
inputs
[
"expected_frames"
],
self
.
config
.
tgt_h
,
self
.
config
.
tgt_w
,
3
),
dtype
=
torch
.
float32
,
device
=
"cpu"
)
self
.
gen_video_final
=
torch
.
zeros
((
self
.
inputs
[
"expected_frames"
],
self
.
input_info
.
target_shape
[
0
],
self
.
input_info
.
target_shape
[
1
]
,
3
),
dtype
=
torch
.
float32
,
device
=
"cpu"
)
self
.
cut_audio_final
=
torch
.
zeros
((
self
.
inputs
[
"expected_frames"
]
*
self
.
_audio_processor
.
audio_frame_rate
),
dtype
=
torch
.
float32
,
device
=
"cpu"
)
self
.
cut_audio_final
=
torch
.
zeros
((
self
.
inputs
[
"expected_frames"
]
*
self
.
_audio_processor
.
audio_frame_rate
),
dtype
=
torch
.
float32
,
device
=
"cpu"
)
else
:
else
:
self
.
gen_video_final
=
None
self
.
gen_video_final
=
None
...
@@ -608,8 +576,8 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -608,8 +576,8 @@ class WanAudioRunner(WanRunner): # type:ignore
else
:
else
:
self
.
segment
=
self
.
inputs
[
"audio_segments"
][
segment_idx
]
self
.
segment
=
self
.
inputs
[
"audio_segments"
][
segment_idx
]
self
.
config
.
seed
=
self
.
config
.
seed
+
segment_idx
self
.
input_info
.
seed
=
self
.
input_info
.
seed
+
segment_idx
torch
.
manual_seed
(
self
.
config
.
seed
)
torch
.
manual_seed
(
self
.
input_info
.
seed
)
# logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
# logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
if
(
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
))
and
not
hasattr
(
self
,
"audio_encoder"
):
if
(
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
))
and
not
hasattr
(
self
,
"audio_encoder"
):
...
@@ -627,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -627,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore
# Reset scheduler for non-first segments
# Reset scheduler for non-first segments
if
segment_idx
>
0
:
if
segment_idx
>
0
:
self
.
model
.
scheduler
.
reset
(
self
.
inputs
[
"previmg_encoder_output"
])
self
.
model
.
scheduler
.
reset
(
self
.
input_info
.
seed
,
self
.
input_info
.
latent_shape
,
self
.
inputs
[
"previmg_encoder_output"
])
@
ProfilingContext4DebugL1
(
"End run segment"
)
@
ProfilingContext4DebugL1
(
"End run segment"
)
def
end_run_segment
(
self
,
segment_idx
):
def
end_run_segment
(
self
,
segment_idx
):
...
@@ -650,7 +618,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -650,7 +618,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if
self
.
va_recorder
:
if
self
.
va_recorder
:
self
.
va_recorder
.
pub_livestream
(
video_seg
,
audio_seg
)
self
.
va_recorder
.
pub_livestream
(
video_seg
,
audio_seg
)
elif
self
.
config
.
get
(
"return_video"
,
False
)
:
elif
self
.
input_info
.
return_result_tensor
:
self
.
gen_video_final
[
self
.
segment
.
start_frame
:
self
.
segment
.
end_frame
].
copy_
(
video_seg
)
self
.
gen_video_final
[
self
.
segment
.
start_frame
:
self
.
segment
.
end_frame
].
copy_
(
video_seg
)
self
.
cut_audio_final
[
self
.
segment
.
start_frame
*
self
.
_audio_processor
.
audio_frame_rate
:
self
.
segment
.
end_frame
*
self
.
_audio_processor
.
audio_frame_rate
].
copy_
(
audio_seg
)
self
.
cut_audio_final
[
self
.
segment
.
start_frame
*
self
.
_audio_processor
.
audio_frame_rate
:
self
.
segment
.
end_frame
*
self
.
_audio_processor
.
audio_frame_rate
].
copy_
(
audio_seg
)
...
@@ -669,7 +637,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -669,7 +637,7 @@ class WanAudioRunner(WanRunner): # type:ignore
return
rank
,
world_size
return
rank
,
world_size
def
init_va_recorder
(
self
):
def
init_va_recorder
(
self
):
output_video_path
=
self
.
config
.
get
(
"save_video_path"
,
None
)
output_video_path
=
self
.
input_info
.
save_result_path
self
.
va_recorder
=
None
self
.
va_recorder
=
None
if
isinstance
(
output_video_path
,
dict
):
if
isinstance
(
output_video_path
,
dict
):
output_video_path
=
output_video_path
[
"data"
]
output_video_path
=
output_video_path
[
"data"
]
...
@@ -722,7 +690,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -722,7 +690,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
init_run
()
self
.
init_run
()
if
self
.
config
.
get
(
"compile"
,
False
):
if
self
.
config
.
get
(
"compile"
,
False
):
self
.
model
.
select_graph_for_compile
()
self
.
model
.
select_graph_for_compile
(
self
.
input_info
)
self
.
video_segment_num
=
"unlimited"
self
.
video_segment_num
=
"unlimited"
fetch_timeout
=
self
.
va_reader
.
segment_duration
+
1
fetch_timeout
=
self
.
va_reader
.
segment_duration
+
1
...
@@ -760,24 +728,20 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -760,24 +728,20 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
va_recorder
=
None
self
.
va_recorder
=
None
@
ProfilingContext4DebugL1
(
"Process after vae decoder"
)
@
ProfilingContext4DebugL1
(
"Process after vae decoder"
)
def
process_images_after_vae_decoder
(
self
,
save_video
=
False
):
def
process_images_after_vae_decoder
(
self
):
if
self
.
config
.
get
(
"return_video"
,
False
)
:
if
self
.
input_info
.
return_result_tensor
:
audio_waveform
=
self
.
cut_audio_final
.
unsqueeze
(
0
).
unsqueeze
(
0
)
audio_waveform
=
self
.
cut_audio_final
.
unsqueeze
(
0
).
unsqueeze
(
0
)
         comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
         return {"video": self.gen_video_final, "audio": comfyui_audio}
         return {"video": None, "audio": None}

     def init_modules(self):
         super().init_modules()
         self.run_input_encoder = self._run_input_encoder_local_r2v_audio

     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
+        base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
-        if self.config.get("lora_configs") and self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
...
@@ -814,7 +778,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter.to(device)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
+        weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
...
@@ -824,28 +788,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.audio_encoder = self.load_audio_encoder()
         self.audio_adapter = self.load_audio_adapter()

-    def set_target_shape(self):
-        """Set target shape for generation"""
-        ret = {}
-        num_channels_latents = 16
-        if self.config.model_cls == "wan2.2_audio":
-            num_channels_latents = self.config.num_channels_latents
-        if self.config.task == "i2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-            ret["lat_h"] = self.config.lat_h
-            ret["lat_w"] = self.config.lat_w
-        else:
-            error_msg = "t2v task is not supported in WanAudioRunner"
-            assert False, error_msg
-        ret["target_shape"] = self.config.target_shape
-        return ret
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape

 @RUNNER_REGISTER("wan2.2_audio")
...
@@ -882,7 +832,7 @@ class Wan22AudioRunner(WanAudioRunner):
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
-        if self.config.task != "i2v":
+        if self.config.task not in ["i2v", "s2v"]:
             return None
         else:
             return Wan2_2_VAE(**vae_config)
...
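The recurring change in this runner is twofold: config reads move from attribute access (self.config.model_path) to dict access (self.config["model_path"]), and the shape helper stops mutating self.config.target_shape, instead returning a latent shape computed from explicit arguments. A minimal sketch of the returned-shape pattern, with illustrative names and values (not LightX2V code):

def get_latent_shape(config: dict, latent_h: int, latent_w: int) -> list:
    # channels, temporal latents, then spatial latents; vae_stride[0] is the temporal stride
    return [
        config.get("num_channels_latents", 16),
        (config["target_video_length"] - 1) // config["vae_stride"][0] + 1,
        latent_h,
        latent_w,
    ]

config = {"target_video_length": 81, "vae_stride": (4, 8, 8)}
print(get_latent_shape(config, 60, 104))  # [16, 21, 60, 104]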
lightx2v/models/runners/wan/wan_causvid_runner.py
View file @ 04812de2
...
@@ -50,7 +50,7 @@ class WanCausVidRunner(WanRunner):
         self.scheduler = WanStepDistillScheduler(self.config)

     def set_target_shape(self):
-        if self.config.task == "i2v":
+        if self.config.task in ["i2v", "s2v"]:
             self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
             # for i2v, frame_seq_length must be reset according to the input shape
             frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
...
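For context, frame_seq_length here is the number of transformer tokens contributed by one latent frame: with a 2x2 spatial patchify, each 2x2 latent patch becomes one token. A toy check with illustrative numbers (not repo code):

lat_h, lat_w = 60, 104
frame_seq_length = (lat_h // 2) * (lat_w // 2)
print(frame_seq_length)  # 1560 tokens per latent frame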
lightx2v/models/runners/wan/wan_distill_runner.py
View file @ 04812de2
...
@@ -17,28 +17,28 @@ class WanDistillRunner(WanRunner):
         super().__init__(config)

     def load_transformer(self):
-        if self.config.get("lora_configs") and self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
             model = WanModel(
-                self.config.model_path,
+                self.config["model_path"],
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
                 lora_wrapper.apply_lora(lora_name, strength)
                 logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
-            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
+            model = WanDistillModel(self.config["model_path"], self.config, self.init_device)
         return model

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             self.scheduler = WanStepDistillScheduler(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

 class MultiDistillModelStruct(MultiModelStruct):
...
@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct):
     def get_current_model_index(self):
         if self.scheduler.step_index < self.boundary_step_index:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
...
@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct):
                 self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
...
@@ -81,8 +81,8 @@ class Wan22MoeDistillRunner(WanDistillRunner):
     def load_transformer(self):
         use_high_lora, use_low_lora = False, False
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            for lora_config in self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "high_noise_model":
                     use_high_lora = True
                 elif lora_config.get("name", "") == "low_noise_model":
...
@@ -90,12 +90,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
         if use_high_lora:
             high_noise_model = WanModel(
-                os.path.join(self.config.model_path, "high_noise_model"),
+                os.path.join(self.config["model_path"], "high_noise_model"),
                 self.config,
                 self.init_device,
             )
             high_lora_wrapper = WanLoraWrapper(high_noise_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "high_noise_model":
                     lora_path = lora_config["path"]
                     strength = lora_config.get("strength", 1.0)
...
@@ -104,19 +104,19 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                     logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             high_noise_model = Wan22MoeDistillModel(
-                os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
+                os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
                 self.config,
                 self.init_device,
             )
         if use_low_lora:
             low_noise_model = WanModel(
-                os.path.join(self.config.model_path, "low_noise_model"),
+                os.path.join(self.config["model_path"], "low_noise_model"),
                 self.config,
                 self.init_device,
             )
             low_lora_wrapper = WanLoraWrapper(low_noise_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 if lora_config.get("name", "") == "low_noise_model":
                     lora_path = lora_config["path"]
                     strength = lora_config.get("strength", 1.0)
...
@@ -125,15 +125,15 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                     logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             low_noise_model = Wan22MoeDistillModel(
-                os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
+                os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
                 self.config,
                 self.init_device,
             )
-        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)
+        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             self.scheduler = Wan22StepDistillScheduler(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
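The MoE-style distill runners above switch between a high-noise and a low-noise expert by step index and pick the matching guidance scale. A minimal sketch of that dispatch, with hypothetical names (the real struct also handles CPU offload):

def pick_model(step_index, boundary_step_index, sample_guide_scale):
    # early (noisier) steps -> expert 0, later steps -> expert 1
    if step_index < boundary_step_index:
        return 0, sample_guide_scale[0]
    return 1, sample_guide_scale[1]

for step in range(4):
    print(step, pick_model(step, boundary_step_index=2, sample_guide_scale=[3.5, 3.0]))
# 0 (0, 3.5)
# 1 (0, 3.5)
# 2 (1, 3.0)
# 3 (1, 3.0)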
lightx2v/models/runners/wan/wan_runner.py
View file @ 04812de2
...
@@ -28,7 +28,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
-from lightx2v.utils.utils import best_output_size, cache_video
+from lightx2v.utils.utils import best_output_size

 @RUNNER_REGISTER("wan2.1")
...
@@ -42,7 +42,7 @@ class WanRunner(DefaultRunner):
     def load_transformer(self):
         model = WanModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
...
@@ -59,7 +59,7 @@ class WanRunner(DefaultRunner):
     def load_image_encoder(self):
         image_encoder = None
-        if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             # offload config
             clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
             if clip_offload:
...
@@ -148,13 +148,13 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
-            "parallel": self.config.parallel,
+            "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
             "load_from_rank0": self.config.get("load_from_rank0", False),
         }
-        if self.config.task not in ["i2v", "flf2v", "animate", "vace"]:
+        if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
             return None
         else:
             return self.vae_cls(**vae_config)
...
@@ -170,7 +170,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
-            "parallel": self.config.parallel,
+            "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
...
@@ -192,9 +192,9 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             scheduler_class = WanScheduler
-        elif self.config.feature_caching == "TaylorSeer":
+        elif self.config["feature_caching"] == "TaylorSeer":
             scheduler_class = WanSchedulerTaylorCaching
         elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
             scheduler_class = WanSchedulerCaching
...
@@ -206,26 +206,28 @@ class WanRunner(DefaultRunner):
         else:
             self.scheduler = scheduler_class(self.config)

-    def run_text_encoder(self, text, img=None):
+    def run_text_encoder(self, input_info):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
-        n_prompt = self.config.get("negative_prompt", "")
+        prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
+        neg_prompt = input_info.negative_prompt
         if self.config["cfg_parallel"]:
             cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
             cfg_p_rank = dist.get_rank(cfg_p_group)
             if cfg_p_rank == 0:
-                context = self.text_encoders[0].infer([text])
+                context = self.text_encoders[0].infer([prompt])
                 context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                 text_encoder_output = {"context": context}
             else:
-                context_null = self.text_encoders[0].infer([n_prompt])
+                context_null = self.text_encoders[0].infer([neg_prompt])
                 context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                 text_encoder_output = {"context_null": context_null}
         else:
-            context = self.text_encoders[0].infer([text])
+            context = self.text_encoders[0].infer([prompt])
             context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
-            context_null = self.text_encoders[0].infer([n_prompt])
+            context_null = self.text_encoders[0].infer([neg_prompt])
             context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
             text_encoder_output = {
                 "context": context,
...
@@ -255,22 +257,22 @@ class WanRunner(DefaultRunner):
     def run_vae_encoder(self, first_frame, last_frame=None):
         h, w = first_frame.shape[2:]
         aspect_ratio = h / w
-        max_area = self.config.target_height * self.config.target_width
+        max_area = self.config["target_height"] * self.config["target_width"]
-        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
+        latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
-        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
+        latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])
+        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)  # Important: latent_shape is used to set the input_info
         if self.config.get("changing_resolution", False):
             assert last_frame is None
-            self.config.lat_h, self.config.lat_w = lat_h, lat_w
             vae_encode_out_list = []
             for i in range(len(self.config["resolution_rate"])):
-                lat_h, lat_w = (
-                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
-                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
+                latent_h_tmp, latent_w_tmp = (
+                    int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
                 )
-                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
+                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp))
-            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
+            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w))
-            return vae_encode_out_list
+            return vae_encode_out_list, latent_shape
         else:
             if last_frame is not None:
                 first_frame_size = first_frame.shape[2:]
...
@@ -282,16 +284,15 @@ class WanRunner(DefaultRunner):
                     round(last_frame_size[1] * last_frame_resize_ratio),
                 ]
                 last_frame = TF.center_crop(last_frame, last_frame_size)
-            self.config.lat_h, self.config.lat_w = lat_h, lat_w
-            vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
-            return vae_encoder_out
+            vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame)
+            return vae_encoder_out, latent_shape

     def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
-        h = lat_h * self.config.vae_stride[1]
+        h = lat_h * self.config["vae_stride"][1]
-        w = lat_w * self.config.vae_stride[2]
+        w = lat_w * self.config["vae_stride"][2]
         msk = torch.ones(
             1,
             1,
-            self.config.target_video_length,
+            self.config["target_video_length"],
             lat_h,
             lat_w,
             device=torch.device("cuda"),
...
@@ -312,7 +313,7 @@ class WanRunner(DefaultRunner):
         vae_input = torch.concat(
             [
                 torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                torch.zeros(3, self.config.target_video_length - 2, h, w),
+                torch.zeros(3, self.config["target_video_length"] - 2, h, w),
                 torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
             ],
             dim=1,
...
@@ -321,7 +322,7 @@ class WanRunner(DefaultRunner):
         vae_input = torch.concat(
             [
                 torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                torch.zeros(3, self.config.target_video_length - 1, h, w),
+                torch.zeros(3, self.config["target_video_length"] - 1, h, w),
             ],
             dim=1,
         ).cuda()
...
@@ -345,32 +346,23 @@ class WanRunner(DefaultRunner):
             "image_encoder_output": image_encoder_output,
         }

-    def set_target_shape(self):
-        num_channels_latents = self.config.get("num_channels_latents", 16)
-        if self.config.task in ["i2v", "flf2v", "animate"]:
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-        elif self.config.task == "t2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                int(self.config.target_height) // self.config.vae_stride[1],
-                int(self.config.target_width) // self.config.vae_stride[2],
-            )
-
-    def save_video_func(self, images):
-        cache_video(
-            tensor=images,
-            save_file=self.config.save_video_path,
-            fps=self.config.get("fps", 16),
-            nrow=1,
-            normalize=True,
-            value_range=(-1, 1),
-        )
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape
+
+    def get_latent_shape_with_target_hw(self, target_h, target_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            int(target_h) // self.config["vae_stride"][1],
+            int(target_w) // self.config["vae_stride"][2],
+        ]
+        return latent_shape

 class MultiModelStruct:
...
@@ -400,7 +392,7 @@ class MultiModelStruct:
     def get_current_model_index(self):
         if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
...
@@ -410,7 +402,7 @@ class MultiModelStruct:
                 self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
...
@@ -434,20 +426,20 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
         high_noise_model = WanModel(
-            os.path.join(self.config.model_path, "high_noise_model"),
+            os.path.join(self.config["model_path"], "high_noise_model"),
             self.config,
             self.init_device,
         )
         low_noise_model = WanModel(
-            os.path.join(self.config.model_path, "low_noise_model"),
+            os.path.join(self.config["model_path"], "low_noise_model"),
             self.config,
             self.init_device,
         )
-        if self.config.get("lora_configs") and self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 base_name = os.path.basename(lora_path)
...
@@ -464,7 +456,7 @@ class Wan22MoeRunner(WanRunner):
         else:
             raise ValueError(f"Unsupported LoRA path: {lora_path}")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])

 @RUNNER_REGISTER("wan2.2")
...
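The run_vae_encoder change keeps the sizing math but renames lat_h/lat_w to latent_h/latent_w and returns the latent shape instead of stashing it on the config. The sizing itself preserves the requested pixel area at the input aspect ratio and snaps to patch multiples; a self-contained sketch with assumed (not repo-confirmed) Wan-style strides, vae_stride=(4, 8, 8) and patch_size=(1, 2, 2):

import numpy as np

def latent_hw(h, w, target_h, target_w, vae_stride=(4, 8, 8), patch_size=(1, 2, 2)):
    aspect_ratio = h / w
    max_area = target_h * target_w
    # floor to a multiple of vae_stride, then to a multiple of patch_size
    latent_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
    latent_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
    return latent_h, latent_w

print(latent_hw(720, 1280, 480, 832))  # (58, 104) -> 464x832 pixels after the 8x VAE stride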
lightx2v/models/runners/wan/wan_sf_runner.py
View file @ 04812de2
...
@@ -9,11 +9,10 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
 from lightx2v.utils.envs import *
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

-torch.manual_seed(42)

 @RUNNER_REGISTER("wan2.1_sf")
 class WanSFRunner(WanRunner):
...
@@ -59,40 +58,37 @@ class WanSFRunner(WanRunner):
         gc.collect()
         return images

-    @ProfilingContext4DebugL2("Run DiT")
-    def run_main(self, total_steps=None):
-        self.init_run()
-        total_blocks = self.scheduler.num_blocks
-        gen_videos = []
-        for seg_index in range(self.video_segment_num):
-            logger.info(f"==> segment_index: {seg_index + 1} / {total_blocks}")
-            total_steps = len(self.scheduler.denoising_step_list)
-            for step_index in range(total_steps):
-                logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
-                with ProfilingContext4DebugL1("step_pre"):
-                    self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=False)
-                with ProfilingContext4DebugL1("🚀 infer_main"):
-                    self.model.infer(self.inputs)
-                with ProfilingContext4DebugL1("step_post"):
-                    self.model.scheduler.step_post()
-            latents = self.model.scheduler.stream_output
-            gen_videos.append(self.run_vae_decoder(latents))
-            # rerun with timestep zero to update KV cache using clean context
-            with ProfilingContext4DebugL1("step_pre_in_rerun"):
-                self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=True)
-            with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
-                self.model.infer(self.inputs)
-        self.gen_video = torch.cat(gen_videos, dim=0)
-        self.end_run()
+    def init_run(self):
+        super().init_run()
+        if self.config.get("compile", False):
+            self.model.select_graph_for_compile()
+
+    @ProfilingContext4DebugL1("End run segment")
+    def end_run_segment(self, segment_idx=None):
+        with ProfilingContext4DebugL1("step_pre_in_rerun"):
+            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
+        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
+            self.model.infer(self.inputs)
+        self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
+
+    @peak_memory_decorator
+    def run_segment(self, total_steps=None):
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            # only for single segment, check stop signal every step
+            if self.video_segment_num == 1:
+                self.check_stop()
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
+            with ProfilingContext4DebugL1("step_pre"):
+                self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False)
+            with ProfilingContext4DebugL1("🚀 infer_main"):
+                self.model.infer(self.inputs)
+            with ProfilingContext4DebugL1("step_post"):
+                self.model.scheduler.step_post()
+            if self.progress_callback:
+                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
+        return self.model.scheduler.stream_output
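The refactor replaces the monolithic run_main loop with init_run / run_segment / end_run_segment hooks; per the diff's own comment, end_run_segment reruns the model at the final (clean) timestep so the KV cache holds denoised context before the next segment starts. A toy, runnable sketch of that control flow, all names illustrative:

class ToyRunner:
    def __init__(self, num_segments, steps):
        self.num_segments, self.steps = num_segments, steps
        self.clips, self.kv_cache = [], []

    def run_segment(self, seg):
        for step in range(self.steps):
            pass  # one denoising step of segment `seg`
        return f"clip{seg}"

    def end_run_segment(self, seg):
        self.kv_cache.append(f"clean_ctx{seg}")  # rerun at timestep zero -> clean context

    def generate(self):
        for seg in range(self.num_segments):
            self.clips.append(self.run_segment(seg))
            self.end_run_segment(seg)
        return self.clips

print(ToyRunner(3, 4).generate())  # ['clip0', 'clip1', 'clip2']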
lightx2v/models/runners/wan/wan_vace_runner.py
View file @ 04812de2
...
@@ -154,10 +154,10 @@ class WanVaceRunner(WanRunner):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]

-    def set_target_shape(self):
-        target_shape = self.latent_shape
-        target_shape[0] = int(target_shape[0] / 2)
-        self.config.target_shape = target_shape
+    def set_input_info_latent_shape(self):
+        latent_shape = self.latent_shape
+        latent_shape[0] = int(latent_shape[0] / 2)
+        return latent_shape

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
...
View file @
04812de2
...
@@ -6,8 +6,8 @@ class BaseScheduler:
...
@@ -6,8 +6,8 @@ class BaseScheduler:
self
.
config
=
config
self
.
config
=
config
self
.
latents
=
None
self
.
latents
=
None
self
.
step_index
=
0
self
.
step_index
=
0
self
.
infer_steps
=
config
.
infer_steps
self
.
infer_steps
=
config
[
"
infer_steps
"
]
self
.
caching_records
=
[
True
]
*
config
.
infer_steps
self
.
caching_records
=
[
True
]
*
config
[
"
infer_steps
"
]
self
.
flag_df
=
False
self
.
flag_df
=
False
self
.
transformer_infer
=
None
self
.
transformer_infer
=
None
self
.
infer_condition
=
True
# cfg status
self
.
infer_condition
=
True
# cfg status
...
...
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
04812de2
...
@@ -13,8 +13,8 @@ class EulerScheduler(WanScheduler):
...
@@ -13,8 +13,8 @@ class EulerScheduler(WanScheduler):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
if
self
.
config
.
parallel
:
if
self
.
config
[
"
parallel
"
]
:
self
.
sp_size
=
self
.
config
.
parallel
.
get
(
"seq_p_size"
,
1
)
self
.
sp_size
=
self
.
config
[
"
parallel
"
]
.
get
(
"seq_p_size"
,
1
)
else
:
else
:
self
.
sp_size
=
1
self
.
sp_size
=
1
...
@@ -33,11 +33,11 @@ class EulerScheduler(WanScheduler):
...
@@ -33,11 +33,11 @@ class EulerScheduler(WanScheduler):
if
self
.
audio_adapter
.
cpu_offload
:
if
self
.
audio_adapter
.
cpu_offload
:
self
.
audio_adapter
.
time_embedding
.
to
(
"cpu"
)
self
.
audio_adapter
.
time_embedding
.
to
(
"cpu"
)
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
self
.
config
[
"
model_cls
"
]
==
"wan2.2_audio"
:
_
,
lat_f
,
lat_h
,
lat_w
=
self
.
latents
.
shape
_
,
lat_f
,
lat_h
,
lat_w
=
self
.
latents
.
shape
F
=
(
lat_f
-
1
)
*
self
.
config
.
vae_stride
[
0
]
+
1
F
=
(
lat_f
-
1
)
*
self
.
config
[
"
vae_stride
"
]
[
0
]
+
1
per_latent_token_len
=
lat_h
*
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
per_latent_token_len
=
lat_h
*
lat_w
//
(
self
.
config
[
"
patch_size
"
]
[
1
]
*
self
.
config
[
"
patch_size
"
]
[
2
])
max_seq_len
=
((
F
-
1
)
//
self
.
config
.
vae_stride
[
0
]
+
1
)
*
per_latent_token_len
max_seq_len
=
((
F
-
1
)
//
self
.
config
[
"
vae_stride
"
]
[
0
]
+
1
)
*
per_latent_token_len
max_seq_len
=
int
(
math
.
ceil
(
max_seq_len
/
self
.
sp_size
))
*
self
.
sp_size
max_seq_len
=
int
(
math
.
ceil
(
max_seq_len
/
self
.
sp_size
))
*
self
.
sp_size
temp_ts
=
(
self
.
mask
[
0
][:,
::
2
,
::
2
]
*
self
.
timestep_input
).
flatten
()
temp_ts
=
(
self
.
mask
[
0
][:,
::
2
,
::
2
]
*
self
.
timestep_input
).
flatten
()
...
@@ -55,13 +55,13 @@ class EulerScheduler(WanScheduler):
...
@@ -55,13 +55,13 @@ class EulerScheduler(WanScheduler):
dim
=
1
,
dim
=
1
,
)
)
def
prepare_latents
(
self
,
targe
t_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
laten
t_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
self
.
config
.
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
self
.
latents
=
torch
.
randn
(
targe
t_shape
[
0
],
laten
t_shape
[
0
],
targe
t_shape
[
1
],
laten
t_shape
[
1
],
targe
t_shape
[
2
],
laten
t_shape
[
2
],
targe
t_shape
[
3
],
laten
t_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
...
@@ -71,8 +71,8 @@ class EulerScheduler(WanScheduler):
...
@@ -71,8 +71,8 @@ class EulerScheduler(WanScheduler):
if
self
.
prev_latents
is
not
None
:
if
self
.
prev_latents
is
not
None
:
self
.
latents
=
(
1.0
-
self
.
mask
)
*
self
.
prev_latents
+
self
.
mask
*
self
.
latents
self
.
latents
=
(
1.0
-
self
.
mask
)
*
self
.
prev_latents
+
self
.
mask
*
self
.
latents
def
prepare
(
self
,
previmg
_encoder_output
=
None
):
def
prepare
(
self
,
seed
,
latent_shape
,
image
_encoder_output
=
None
):
self
.
prepare_latents
(
se
lf
.
config
.
targe
t_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
se
ed
,
laten
t_shape
,
dtype
=
torch
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
device
)
...
@@ -93,11 +93,11 @@ class EulerScheduler(WanScheduler):
...
@@ -93,11 +93,11 @@ class EulerScheduler(WanScheduler):
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
and
self
.
prev_latents
is
not
None
:
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
and
self
.
prev_latents
is
not
None
:
self
.
latents
=
(
1.0
-
self
.
mask
)
*
self
.
prev_latents
+
self
.
mask
*
self
.
latents
self
.
latents
=
(
1.0
-
self
.
mask
)
*
self
.
prev_latents
+
self
.
mask
*
self
.
latents
def
reset
(
self
,
previmg
_encoder_output
=
None
):
def
reset
(
self
,
seed
,
latent_shape
,
image
_encoder_output
=
None
):
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
self
.
prev_latents
=
previmg
_encoder_output
[
"prev_latents"
]
self
.
prev_latents
=
image
_encoder_output
[
"prev_latents"
]
self
.
prev_len
=
previmg
_encoder_output
[
"prev_len"
]
self
.
prev_len
=
image
_encoder_output
[
"prev_len"
]
self
.
prepare_latents
(
se
lf
.
config
.
targe
t_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
se
ed
,
laten
t_shape
,
dtype
=
torch
.
float32
)
def
unsqueeze_to_ndim
(
self
,
in_tensor
,
tgt_n_dim
):
def
unsqueeze_to_ndim
(
self
,
in_tensor
,
tgt_n_dim
):
if
in_tensor
.
ndim
>
tgt_n_dim
:
if
in_tensor
.
ndim
>
tgt_n_dim
:
...
...
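Across the schedulers, prepare()/reset() now take an explicit seed and latent_shape instead of reading self.config.seed and self.config.target_shape, which makes noise generation reproducible per request. A minimal standalone sketch of the seeded-latent idea (illustrative, not the scheduler itself):

import torch

def make_latents(seed, latent_shape, device="cpu", dtype=torch.float32):
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(*latent_shape, dtype=dtype, device=device, generator=generator)

a = make_latents(42, (16, 21, 60, 104))
b = make_latents(42, (16, 21, 60, 104))
print(torch.equal(a, b))  # True: same seed and shape give identical noise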
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @ 04812de2
...
@@ -19,16 +19,16 @@ class WanScheduler4ChangingResolution:
             config["changing_resolution_steps"] = [config.infer_steps // 2]
         assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

-    def prepare_latents(self, target_shape, dtype=torch.float32):
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents_list = []
         for i in range(len(self.config["resolution_rate"])):
             self.latents_list.append(
                 torch.randn(
-                    target_shape[0],
-                    target_shape[1],
-                    int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
-                    int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
+                    latent_shape[0],
+                    latent_shape[1],
+                    int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
                     dtype=dtype,
                     device=self.device,
                     generator=self.generator,
...
@@ -38,10 +38,10 @@ class WanScheduler4ChangingResolution:
         # add original resolution latents
         self.latents_list.append(
             torch.randn(
-                target_shape[0],
-                target_shape[1],
-                target_shape[2],
-                target_shape[3],
+                latent_shape[0],
+                latent_shape[1],
+                latent_shape[2],
+                latent_shape[3],
                 dtype=dtype,
                 device=self.device,
                 generator=self.generator,
...
lightx2v/models/schedulers/wan/scheduler.py
View file @ 04812de2
-import gc
 from typing import List, Optional, Union

 import numpy as np
...
@@ -12,22 +11,22 @@ class WanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.device = torch.device("cuda")
-        self.infer_steps = self.config.infer_steps
+        self.infer_steps = self.config["infer_steps"]
-        self.target_video_length = self.config.target_video_length
+        self.target_video_length = self.config["target_video_length"]
-        self.sample_shift = self.config.sample_shift
+        self.sample_shift = self.config["sample_shift"]
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
         self.solver_order = 2
         self.noise_pred = None
-        self.sample_guide_scale = self.config.sample_guide_scale
+        self.sample_guide_scale = self.config["sample_guide_scale"]
-        self.caching_records_2 = [True] * self.config.infer_steps
+        self.caching_records_2 = [True] * self.config["infer_steps"]

-    def prepare(self, image_encoder_output=None):
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
...
@@ -48,18 +47,18 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

-    def prepare_latents(self, target_shape, dtype=torch.float32):
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
-            target_shape[3],
+            latent_shape[0],
+            latent_shape[1],
+            latent_shape[2],
+            latent_shape[3],
             dtype=dtype,
             device=self.device,
             generator=self.generator,
         )
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.mask = masks_like(self.latents, zero=True)
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
...
@@ -117,7 +116,7 @@ class WanScheduler(BaseScheduler):
         x0_pred = sample - sigma_t * model_output
         return x0_pred

-    def reset(self, step_index=None):
+    def reset(self, seed, latent_shape, step_index=None):
         if step_index is not None:
             self.step_index = step_index
         self.model_outputs = [None] * self.solver_order
...
@@ -126,9 +125,7 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.this_order = None
         self.lower_order_nums = 0
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
-        gc.collect()
-        torch.cuda.empty_cache()

     def multistep_uni_p_bh_update(
         self,
...
@@ -325,7 +322,7 @@ class WanScheduler(BaseScheduler):
     def step_pre(self, step_index):
         super().step_pre(step_index)
         self.timestep_input = torch.stack([self.timesteps[self.step_index]])
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()

     def step_post(self):
...
@@ -367,5 +364,5 @@ class WanScheduler(BaseScheduler):
             self.lower_order_nums += 1

         self.latents = prev_sample
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
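The recurring wan2.2 i2v/s2v branch above blends conditioning into the latents with a binary mask: where the mask is 0 the latent is pinned to the VAE-encoded frames, where it is 1 it stays free noise. A small illustrative example of that blend (shapes and values are made up):

import torch

latents = torch.randn(16, 21, 8, 8)
vae_encoder_out = torch.zeros(16, 21, 8, 8)  # stand-in for encoded conditioning frames
mask = torch.ones_like(latents)
mask[:, 0] = 0.0                             # pin the first latent frame
latents = (1.0 - mask) * vae_encoder_out + mask * latents
print(latents[:, 0].abs().max().item())      # 0.0 -> conditioning preserved after the blend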
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
View file @ 04812de2
...
@@ -9,24 +9,25 @@ class WanSFScheduler(WanScheduler):
         super().__init__(config)
         self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        self.num_frame_per_block = self.config.sf_config.num_frame_per_block
+        self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
-        self.num_output_frames = self.config.sf_config.num_output_frames
+        self.num_output_frames = self.config["sf_config"]["num_output_frames"]
         self.num_blocks = self.num_output_frames // self.num_frame_per_block
-        self.denoising_step_list = self.config.sf_config.denoising_step_list
+        self.denoising_step_list = self.config["sf_config"]["denoising_step_list"]
+        self.infer_steps = len(self.denoising_step_list)
         self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
         self.num_input_frames = 0
         self.denoising_strength = 1.0
         self.sigma_max = 1.0
         self.sigma_min = 0
-        self.sf_shift = self.config.sf_config.shift
+        self.sf_shift = self.config["sf_config"]["shift"]
         self.inverse_timesteps = False
         self.extra_one_step = True
         self.reverse_sigmas = False
-        self.num_inference_steps = self.config.sf_config.num_inference_steps
+        self.num_inference_steps = self.config["sf_config"]["num_inference_steps"]
         self.context_noise = 0

-    def prepare(self, image_encoder_output=None):
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
-        self.latents = torch.randn(self.config.target_shape, device=self.device, dtype=self.dtype)
+        self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
         timesteps = []
         for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
...
@@ -39,7 +40,7 @@ class WanSFScheduler(WanScheduler):
             timesteps.append(frame_steps)
         self.timesteps = timesteps
-        self.noise_pred = torch.zeros(self.config.target_shape, device=self.device, dtype=self.dtype)
+        self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)
         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
         if self.extra_one_step:
...
@@ -91,7 +92,7 @@ class WanSFScheduler(WanScheduler):
         x0_pred = x0_pred.to(original_dtype)
         # add noise
-        if self.step_index < len(self.denoising_step_list) - 1:
+        if self.step_index < self.infer_steps - 1:
             timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
             timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
             sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
...