xuwx1 / LightX2V · Commits

Commit cf04772a
Fix wan22 ti2v vae & update audio profiler (#246)

Authored Aug 26, 2025 by Yang Yong (雍洋); committed by GitHub on Aug 26, 2025
Parent: cb83f2f8
Showing 7 changed files with 15 additions and 135 deletions:

lightx2v/infer.py                                  +1   -1
lightx2v/models/networks/wan/audio_model.py        +0   -13
lightx2v/models/runners/base_runner.py             +1   -0
lightx2v/models/runners/default_runner.py          +5   -5
lightx2v/models/runners/wan/wan_audio_runner.py    +7   -92
lightx2v/models/runners/wan/wan_runner.py          +1   -0
scripts/wan22/run_wan22_moe_i2v_audio.sh           +0   -24
lightx2v/infer.py

@@ -8,7 +8,7 @@ from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  #
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
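The Wan22AudioRunner and Wan22MoeAudioRunner imports disappear here because those classes are deleted from wan_audio_runner.py below. These imports exist only for their side effects (hence # noqa: F401): importing a runner module executes its @RUNNER_REGISTER(...) decorator, which is what makes a class reachable via --model_cls. Below is a minimal sketch of that decorator-registry pattern with a dict-backed registry and illustrative names; it is not LightX2V's actual registry_factory code.

# Sketch of a decorator-based runner registry (illustrative only;
# RUNNER_REGISTER in lightx2v.utils.registry_factory may differ).
_RUNNERS = {}

def register_runner(name):
    def decorator(cls):
        _RUNNERS[name] = cls  # registration happens at import time
        return cls
    return decorator

@register_runner("demo_runner")  # hypothetical model_cls key
class DemoRunner:
    def __init__(self, config):
        self.config = config

def build_runner(model_cls, config):
    # Works only if the module defining model_cls has already been
    # imported, which is why infer.py keeps the unused-looking imports.
    return _RUNNERS[model_cls](config)

print(build_runner("demo_runner", {}).config)  # -> {}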
lightx2v/models/networks/wan/audio_model.py

-import glob
-import os
 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
 from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer

@@ -29,13 +26,3 @@ class WanAudioModel(WanModel):
     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter
         self.transformer_infer.set_audio_adapter(self.audio_adapter)
-
-
-class Wan22MoeAudioModel(WanAudioModel):
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
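The deleted Wan22MoeAudioModel._load_ckpt merged every *.safetensors shard under the model directory into a single weight dict via dict.update. A self-contained sketch of that shard-merging pattern using the safetensors library follows; the dtype handling and sorting are assumptions, and _load_safetensor_to_dict is LightX2V-internal and not reproduced here.

import glob
import os

import torch
from safetensors import safe_open

def load_sharded_safetensors(model_path, dtype=torch.bfloat16):
    """Merge all *.safetensors shards under model_path into one dict."""
    weight_dict = {}
    for file_path in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                # Later shards win on key collisions, mirroring the
                # dict.update() in the deleted method.
                weight_dict[key] = f.get_tensor(key).to(dtype)
    return weight_dict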
lightx2v/models/runners/base_runner.py

@@ -11,6 +11,7 @@ class BaseRunner(ABC):
     def __init__(self, config):
         self.config = config
+        self.vae_encoder_need_img_original = False

     def load_transformer(self):
         """Load transformer model
lightx2v/models/runners/default_runner.py

@@ -145,16 +145,16 @@ class DefaultRunner(BaseRunner):
         gc.collect()

     def read_image_input(self, img_path):
-        img = Image.open(img_path).convert("RGB")
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
-        return img
+        img_ori = Image.open(img_path).convert("RGB")
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        return img, img_ori

     @ProfilingContext("Run Encoders")
     def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.config["image_path"])
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img)
+        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
         text_encoder_output = self.run_text_encoder(prompt, img)
         torch.cuda.empty_cache()
         gc.collect()
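This is the core of the "wan22 ti2v vae" fix: read_image_input now returns both the normalized CUDA tensor and the original PIL image, and the VAE encoder receives the original image whenever a runner sets vae_encoder_need_img_original (default False in BaseRunner above; Wan22DenseRunner flips it to True in wan_runner.py below). A condensed, runnable sketch of the routing; class names are shortened, pick_vae_input is a hypothetical helper, and .cuda() is omitted so it runs on CPU.

from PIL import Image
import torchvision.transforms.functional as TF

class RunnerSketch:
    vae_encoder_need_img_original = False  # BaseRunner default

    def read_image_input(self, img_path):
        img_ori = Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0)  # in [-1, 1]
        return img, img_ori

    def pick_vae_input(self, img_path):  # hypothetical helper for illustration
        img, img_ori = self.read_image_input(img_path)
        return img_ori if self.vae_encoder_need_img_original else img

class DenseRunnerSketch(RunnerSketch):
    vae_encoder_need_img_original = True  # as Wan22DenseRunner now sets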
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -14,19 +14,17 @@ from einops import rearrange
 from loguru import logger
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
-from transformers import AutoFeatureExtractor

 from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
 from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
-from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
+from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import load_weights, save_to_video, vae_to_comfyui_image


 def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):

@@ -398,6 +396,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.cut_audio_list = []
         self.prev_video = None

+    @ProfilingContext4Debug("Init run segment")
     def init_run_segment(self, segment_idx):
         self.segment_idx = segment_idx

@@ -421,6 +420,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)

+    @ProfilingContext4Debug("End run segment")
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)

@@ -446,6 +446,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         del self.gen_video
         torch.cuda.empty_cache()

+    @ProfilingContext4Debug("Process after vae decoder")
     def process_images_after_vae_decoder(self, save_video=True):
         # Merge results
         gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()

@@ -599,89 +600,3 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
-
-
-@RUNNER_REGISTER("wan2.2_audio")
-class Wan22AudioRunner(WanAudioRunner):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def load_vae_decoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-        }
-        vae_decoder = Wan2_2_VAE(**vae_config)
-        return vae_decoder
-
-    def load_vae_encoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-        }
-        if self.config.task != "i2v":
-            return None
-        else:
-            return Wan2_2_VAE(**vae_config)
-
-    def load_vae(self):
-        vae_encoder = self.load_vae_encoder()
-        vae_decoder = self.load_vae_decoder()
-        return vae_encoder, vae_decoder
-
-
-@RUNNER_REGISTER("wan2.2_moe_audio")
-class Wan22MoeAudioRunner(WanAudioRunner):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def load_transformer(self):
-        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = Wan22MoeAudioModel(
-            os.path.join(self.config.model_path, "high_noise_model"),
-            self.config,
-            self.init_device,
-        )
-        low_noise_model = Wan22MoeAudioModel(
-            os.path.join(self.config.model_path, "low_noise_model"),
-            self.config,
-            self.init_device,
-        )
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            for lora_config in self.config.lora_configs:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                if lora_config.name == "high_noise_model":
-                    lora_wrapper = WanLoraWrapper(high_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
-                if lora_config.name == "low_noise_model":
-                    lora_wrapper = WanLoraWrapper(low_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
-        # XXX: trick
-        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
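This file carries both halves of the commit: the profiler update adds @ProfilingContext4Debug decorators around the per-segment stages (init_run_segment, end_run_segment, process_images_after_vae_decoder), and the Wan2.2 audio runner subclasses are deleted along with their bespoke Wan2_2_VAE loading. The real ProfilingContext4Debug lives in lightx2v.utils.profiler; the sketch below only illustrates the general shape of a combined context-manager/decorator timer and is not LightX2V's implementation.

import time
from contextlib import ContextDecorator

class DebugProfiler(ContextDecorator):
    """Usable as `with DebugProfiler("stage"):` or as `@DebugProfiler("stage")`."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        print(f"[profile] {self.name}: {time.perf_counter() - self.start:.3f}s")
        return False  # never swallow exceptions

@DebugProfiler("Init run segment")
def init_run_segment(segment_idx):
    time.sleep(0.01)  # stand-in for real work

init_run_segment(0)  # prints something like "[profile] Init run segment: 0.010s"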
lightx2v/models/runners/wan/wan_runner.py

@@ -430,6 +430,7 @@ class Wan22MoeRunner(WanRunner):
 class Wan22DenseRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
+        self.vae_encoder_need_img_original = True

     def load_vae_decoder(self):
         # offload config
scripts/wan22/run_wan22_moe_i2v_audio.sh
deleted (mode 100755 → 0)

This script launched the wan2.2_moe_audio runner that the commit removes from wan_audio_runner.py, so it is deleted along with it.

-#!/bin/bash
-
-# set path and first
-lightx2v_path=
-model_path=
-
-export CUDA_VISIBLE_DEVICES=0
-
-# set environment variables
-source ${lightx2v_path}/scripts/base/base.sh
-export TORCH_CUDA_ARCH_LIST="9.0"
-export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
-
-python -m lightx2v.infer \
-    --model_cls wan2.2_moe_audio \
-    --task i2v \
-    --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v_audio.json \
-    --prompt "The video features a old lady is saying something and knitting a sweater." \
-    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
-    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
-    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
-    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_audio.mp4