xuwx1 / LightX2V · Commits

Commit 75eac23c
Authored Jul 23, 2025 by wangshankun

Add:Audio CM Scheduler

parent d86b6917
Showing 4 changed files with 6 additions and 11 deletions
configs/audio_driven/wan_i2v_audio.json                            +2 -3
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py    +0 -2
lightx2v/models/runners/wan/wan_audio_runner.py                    +4 -4
lightx2v/models/schedulers/wan/audio/scheduler.py                  +0 -2
configs/audio_driven/wan_i2v_audio.json  (+2 -3)

```diff
@@ -6,13 +6,12 @@
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
-    "self_attn_1_type": "radial_attn",
+    "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
-    "use_tiling_vae": true,
-    "cpu_offload": false
+    "use_tiling_vae": true
 }
```
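The old config appears to have carried a duplicated "cpu_offload" entry, which json.load resolves silently by keeping the last occurrence. A minimal sketch (illustrative, not part of the repository) of a loader that rejects duplicate keys in configs like this one:

```python
import json

def load_config_strict(path):
    """Load a JSON config, rejecting duplicate keys instead of silently
    keeping the last occurrence (json.load's default behaviour)."""
    def check_pairs(pairs):
        keys = [k for k, _ in pairs]
        dupes = sorted({k for k in keys if keys.count(k) > 1})
        if dupes:
            raise ValueError(f"duplicate keys in {path}: {dupes}")
        return dict(pairs)

    with open(path) as f:
        return json.load(f, object_pairs_hook=check_pairs)

# cfg = load_config_strict("configs/audio_driven/wan_i2v_audio.json")
```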
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py  (+0 -2)

```diff
@@ -24,8 +24,6 @@ class WanAudioPreInfer(WanPreInfer):
         self.text_len = config["text_len"]

     def infer(self, weights, inputs, positive):
         ltnt_frames = self.scheduler.latents.size(1)
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
         prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
```
lightx2v/models/runners/wan/wan_audio_runner.py  (+4 -4)

```diff
@@ -19,7 +19,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
 from loguru import logger
 import torch.distributed as dist
@@ -327,7 +327,7 @@ class WanAudioRunner(WanRunner):
         super().__init__(config)

     def init_scheduler(self):
-        scheduler = EulerSchedulerTimestepFix(self.config)
+        scheduler = ConsistencyModelScheduler(self.config)
         self.model.set_scheduler(scheduler)

     def load_audio_models(self):
```
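init_scheduler now constructs ConsistencyModelScheduler unconditionally. If both schedulers should remain selectable, one option is a config switch; the sketch below is illustrative only and assumes a hypothetical "scheduler_type" key that wan_i2v_audio.json does not define.

```python
# Illustrative sketch of a config-driven scheduler switch for WanAudioRunner.
# The "scheduler_type" key is hypothetical; the committed code always uses
# ConsistencyModelScheduler.
from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler

def init_scheduler(self):
    scheduler_map = {
        "euler": EulerSchedulerTimestepFix,
        "consistency": ConsistencyModelScheduler,
    }
    name = getattr(self.config, "scheduler_type", "consistency")
    scheduler = scheduler_map.get(name, ConsistencyModelScheduler)(self.config)
    self.model.set_scheduler(scheduler)
```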
```diff
@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner):
         last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
         last_frames = last_frames.cpu().detach().numpy()
-        last_frames = add_noise_to_frames(last_frames)
+        last_frames = add_noise_to_frames(last_frames)  # mean:-3.0 std:0.5
         last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
         last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
```
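add_noise_to_frames and add_mask_to_frames are not part of this diff; the sketch below is a guess at their shape from the call sites and the inline comments (noise strength drawn log-normally around exp(-3.0), and roughly 10% of values masked), not the repository's actual implementation.

```python
import numpy as np

# Hypothetical reconstructions inferred from the call sites above; the real
# helpers used by wan_audio_runner.py may differ.
def add_noise_to_frames(frames: np.ndarray, mean: float = -3.0, std: float = 0.5) -> np.ndarray:
    """Perturb the carried-over frames with Gaussian noise whose strength is
    sampled log-normally: sigma = exp(N(mean, std))."""
    sigma = np.exp(np.random.normal(mean, std))
    noise = np.random.randn(*frames.shape).astype(frames.dtype)
    return frames + sigma * noise

def add_mask_to_frames(frames: np.ndarray, mask_rate: float = 0.1) -> np.ndarray:
    """Randomly zero out a mask_rate fraction of values in the frames."""
    keep = (np.random.rand(*frames.shape) >= mask_rate).astype(frames.dtype)
    return frames * keep
```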
```diff
@@ -583,7 +583,7 @@ class WanAudioRunner(WanRunner):
         latents = self.model.scheduler.latents
         generator = self.model.scheduler.generator
         gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        gen_video = torch.clamp(gen_video, -1, 1)
+        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
         start_frame = 0 if idx == 0 else prev_frame_length
         start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
```
lightx2v/models/schedulers/wan/audio/scheduler.py  (+0 -2)

```diff
@@ -34,7 +34,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.infer_steps = self.config.infer_steps
         self.target_video_length = self.config.target_video_length
         self.sample_shift = self.config.sample_shift
         self.shift = 1
         self.num_train_timesteps = 1000
         self.step_index = None
```
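This hunk keeps sample_shift (5 in wan_i2v_audio.json) alongside a hard-coded shift = 1 and num_train_timesteps = 1000, but the part of the scheduler that builds the sigma schedule is collapsed. For orientation, the timestep-shift transform commonly used by Wan-style flow-matching schedulers looks like the sketch below; this is an assumption about the general technique, not code from scheduler.py.

```python
import torch

# Assumed shape of a shifted flow-matching sigma schedule; the actual
# EulerSchedulerTimestepFix may build its timesteps differently.
def shifted_sigmas(infer_steps: int, sample_shift: float, num_train_timesteps: int = 1000) -> torch.Tensor:
    sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, infer_steps)
    # A shift > 1 keeps more of the trajectory at high noise levels.
    return sample_shift * sigmas / (1 + (sample_shift - 1) * sigmas)

# Timesteps in the training range would then be shifted_sigmas(steps, 5.0) * 1000.
```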
```diff
@@ -94,7 +93,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
 class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
     def step_post(self):
         logger.info(f"Step index: {self.step_index}, self.timestep: {self.timesteps[self.step_index]}")
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)
         sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
```
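The remainder of step_post is collapsed in this view. For orientation, a consistency-model style update in a flow-matching parameterization typically predicts the clean sample and re-noises it to the next sigma, roughly as sketched below; this is a generic formulation, not the actual tail of ConsistencyModelScheduler.step_post.

```python
import torch

# Generic consistency-model step under the flow-matching convention
# x_t = (1 - sigma) * x0 + sigma * noise, with the model predicting velocity.
# Sketch only; not taken from scheduler.py.
def cm_step(sample: torch.Tensor, model_output: torch.Tensor,
            sigma: torch.Tensor, sigma_next: float,
            generator: torch.Generator | None = None) -> torch.Tensor:
    x0 = sample - sigma * model_output  # predicted clean latents
    if sigma_next > 0:
        noise = torch.randn(sample.shape, generator=generator,
                            device=sample.device, dtype=sample.dtype)
        return (1 - sigma_next) * x0 + sigma_next * noise  # re-noise to the next level
    return x0
```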