GitLab: xuwx1 / LightX2V · Commits

Commit 783b3a72
Authored Aug 07, 2025 by gushiqiao; committed by GitHub on Aug 07, 2025

Fix audio model compile and offload bugs (Dev gsq)

Parents: 9067043e, 92f067f1
Changes: 2 changed files, with 85 additions and 15 deletions (+85, -15)

- lightx2v/models/networks/wan/infer/transformer_infer.py (+25, -12)
- lightx2v/models/runners/wan/wan_audio_runner.py (+60, -3)
lightx2v/models/networks/wan/infer/transformer_infer.py (view file @ 783b3a72)
@@ -85,6 +85,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k
 
+    def compute_freqs(self, q, grid_sizes, freqs):
+        if "audio" in self.config.get("model_cls", ""):
+            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        else:
+            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+        return freqs_i
+
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
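
This hunk moves the `compute_freqs` method up so it is defined before the `torch.compile`-wrapped `infer` (the matching removal is in the last hunk of this file). The `disable=` keyword is standard `torch.compile` API: a runtime flag chooses between compiled and eager execution without duplicating code paths. Below is a minimal, self-contained sketch of that gating pattern; `CHECK_ENABLE_GRAPH_MODE` is stubbed with a hypothetical environment variable, and the project's real check may differ.

```python
import os

import torch


def CHECK_ENABLE_GRAPH_MODE() -> bool:
    # Hypothetical stand-in for the project's check: gate graph mode
    # on an environment variable.
    return os.environ.get("ENABLE_GRAPH_MODE", "0") == "1"


@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # With disable=True this runs as plain eager PyTorch; with
    # disable=False it is traced and compiled on first call.
    return 2.0 * x + y


print(scaled_add(torch.ones(2), torch.zeros(2)))  # tensor([2., 2.])
```

Note that the `disable=` argument is evaluated once, when the decorator runs at import time, so the flag has to be set before the module is loaded.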
@@ -109,6 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -137,6 +146,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
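
These two hunks insert the audio hook into both weight-offloading inference loops, after the block computes and before `swap_weights()` rotates in the next block's weights. A toy, runnable sketch of that ordering follows; all names are illustrative stand-ins, not the project's classes, and the real `modify_func` also receives `grid_sizes` and keyword arguments.

```python
def run_blocks(x, blocks, audio_hooks=None):
    for block_idx, block in enumerate(blocks):
        x = block(x)                      # per-block transformer compute
        if audio_hooks:                   # new: audio hook, applied pre-swap
            for hooks in audio_hooks:
                if block_idx in hooks:
                    x = hooks[block_idx](x)
        # weights_stream_mgr.swap_weights() would prefetch the next
        # block's weights here in the real loop
    return x


assert run_blocks(1.0, [lambda v: v * 2], [{0: lambda v: v + 1}]) == 3.0
```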
@@ -179,6 +190,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+                    if audio_dit_blocks:
+                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
 
                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
@@ -239,6 +252,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+                    if audio_dit_blocks:
+                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
 
                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -275,6 +290,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis
 
+    @torch._dynamo.disable
+    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
+        for ipa_out in audio_dit_blocks:
+            if block_idx in ipa_out:
+                cur_modify = ipa_out[block_idx]
+                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+        return x
+
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
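
The new `_apply_audio_dit` helper iterates a list of dicts holding arbitrary Python callables, which `torch.compile` cannot trace cleanly; decorating it with `@torch._dynamo.disable` carves it out of the compiled graph instead of triggering graph breaks inside the compiled `infer`. A minimal sketch of the same carve-out, with hypothetical hook names:

```python
import torch


@torch._dynamo.disable
def apply_hooks(x: torch.Tensor, hooks: dict) -> torch.Tensor:
    # Dict dispatch to arbitrary callables runs in eager mode; the
    # surrounding compiled function treats this call as opaque.
    for fn in hooks.values():
        x = fn(x)
    return x


@torch.compile
def forward(x: torch.Tensor, hooks: dict) -> torch.Tensor:
    x = x * 2
    return apply_hooks(x, hooks)


print(forward(torch.ones(2), {"bias": lambda t: t + 1}))  # tensor([3., 3.])
```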
@@ -287,12 +310,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
-            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
-                for ipa_out in audio_dit_blocks:
-                    if block_idx in ipa_out:
-                        cur_modify = ipa_out[block_idx]
-                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x
 
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
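
Besides routing through the new helper, this replacement collapses `is not None and len(...) > 0` into a plain truthiness test, which treats `None` and an empty list the same way:

```python
# The short form agrees with the long form in all three relevant cases:
for blocks in (None, [], [{0: "hook"}]):
    assert bool(blocks) == (blocks is not None and len(blocks) > 0)
```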
@@ -328,13 +348,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
 
-    def compute_freqs(self, q, grid_sizes, freqs):
-        if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
-
     def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
lightx2v/models/runners/wan/wan_audio_runner.py (view file @ 783b3a72)
@@ -3,7 +3,7 @@ import os
 import subprocess
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -302,7 +302,7 @@ class VideoGenerator:
         return mask.transpose(0, 1)
 
     @torch.no_grad()
-    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+    def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
         """Generate video segment"""
         # Update inputs with audio features
         inputs["audio_encoder_output"] = audio_features
@@ -352,7 +352,8 @@ class VideoGenerator:
             inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
 
         # Run inference loop
-        total_steps = self.model.scheduler.infer_steps
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
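
Making `total_steps` an explicit parameter with a `None` default keeps existing callers on the scheduler's configured `infer_steps` while letting new call sites (such as `run_step` below) cap a segment to a single denoising step. The fallback idiom in isolation, with an illustrative default:

```python
def run_denoise_loop(total_steps=None, scheduler_infer_steps=30):
    # None means "use the scheduler's configured step count";
    # an explicit integer overrides it (e.g. 1 for a warm-up pass).
    if total_steps is None:
        total_steps = scheduler_infer_steps
    return list(range(total_steps))


assert len(run_denoise_loop()) == 30
assert len(run_denoise_loop(total_steps=1)) == 1
```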
@@ -694,6 +695,62 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
 
+    def run_step(self):
+        """Optimized pipeline with modular components"""
+        self.initialize()
+        assert self._audio_processor is not None
+        assert self._audio_preprocess is not None
+        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+
+            # Re-initialize scheduler after image encoding sets correct dimensions
+            self.init_scheduler()
+            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+            # Re-create video generator with updated model/scheduler
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+            # Process audio
+            audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+            video_duration = self.config.get("video_duration", 5)
+            target_fps = self.config.get("target_fps", 16)
+            max_num_frames = self.config.get("target_video_length", 81)
+            audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+            expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
+
+            # Segment audio
+            audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+            self._video_generator.total_segments = len(audio_segments)
+
+            # Generate video segments
+            prev_video = None
+            torch.manual_seed(self.config.seed)
+
+            # Process audio features
+            audio_features = self._audio_preprocess(
+                audio_segments[0].audio_array,
+                sampling_rate=self._audio_processor.audio_sr,
+                return_tensors="pt",
+            ).input_values.squeeze(0).to(self.model.device)
+
+            # Generate video segment
+            with memory_efficient_inference():
+                self._video_generator.generate_segment(
+                    self.inputs.copy(),  # Copy to avoid modifying original
+                    audio_features,
+                    prev_video=prev_video,
+                    prev_frame_length=5,
+                    segment_idx=0,
+                    total_steps=1,
+                )
+
+        # Final cleanup
+        self.end_run()
+
 
 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
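
The new `run_step` drives the whole audio-to-video pipeline for exactly one segment with `total_steps=1`, which reads like a warm-up or smoke-test entry point: it exercises prompt enhancement, scheduler setup, audio segmentation, and a single denoising step end to end. A hypothetical invocation follows; the constructor wiring and the shape of the config object are assumptions for illustration, not shown in this diff (the runner reads both `config["..."]` keys and attributes such as `config.seed`, so the real config is presumably an attribute-style dict).

```python
# Hypothetical driver; WanAudioRunner's constructor and config schema
# are assumptions, not taken from this diff.
config = {
    "use_prompt_enhancer": False,
    "audio_path": "speech.wav",
    "video_duration": 5,        # seconds
    "target_fps": 16,
    "target_video_length": 81,  # frames per segment
}

runner = WanAudioRunner(config)
runner.run_step()  # one segment, one denoising step (total_steps=1)
```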