xuwx1 / LightX2V / Commits / 00962c67
"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "be937d4a201a96a537d645992ebfbe21f70cc493"
Commit 00962c67, authored Aug 07, 2025 by gushiqiao

Fix audio model compile and offload bugs

Parent: 4389450a
Showing 3 changed files with 90 additions and 17 deletions (+90 −17):
lightx2v/models/networks/wan/audio_adapter.py             +1  −1
lightx2v/models/networks/wan/infer/transformer_infer.py   +30 −14
lightx2v/models/runners/wan/wan_audio_runner.py           +59 −2
lightx2v/models/networks/wan/audio_adapter.py

@@ -113,7 +113,7 @@ class PerceiverAttentionCA(nn.Module):
         shift_scale_gate = torch.zeros((1, 3, inner_dim))
         shift_scale_gate[:, 2] = 1
         self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

     def forward(self, x, latents, t_emb, q_lens, k_lens):
         """x shape (batchsize, latent_frame, audio_tokens_per_latent, model_dim)
         latents (batchsize, length, model_dim)"""
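For context on the buffer registration shown in this hunk: a tensor stored via register_buffer travels with the module through .to()/.cuda() calls during offloading, unlike a plain Python attribute, and persistent=False keeps it out of state_dict(), so existing checkpoints load unchanged. A minimal sketch of the idiom (the Demo module is hypothetical):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, inner_dim=8):
        super().__init__()
        shift_scale_gate = torch.zeros((1, 3, inner_dim))
        shift_scale_gate[:, 2] = 1  # gate channel starts at 1, shift/scale at 0
        # Registered buffers follow the module across device moves;
        # persistent=False excludes them from state_dict().
        self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

m = Demo()
print("shift_scale_gate" in m.state_dict())  # False: not persisted
m = m.to("cpu")  # the buffer moves with the module (likewise for .cuda())
print(m.shift_scale_gate.device)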
lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -83,7 +83,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

+    def compute_freqs(self, q, grid_sizes, freqs):
+        if "audio" in self.config.get("model_cls", ""):
+            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        else:
+            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+        return freqs_i
+
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
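The @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) decorator above gates compilation on a runtime check: when graph mode is off, disable=True makes torch.compile a no-op and the function runs eagerly. A minimal sketch of the same pattern, with a hypothetical GRAPH_MODE environment flag standing in for CHECK_ENABLE_GRAPH_MODE:

import os
import torch

def graph_mode_enabled():
    # Stand-in for CHECK_ENABLE_GRAPH_MODE(); the env var name is assumed.
    return os.environ.get("GRAPH_MODE", "0") == "1"

@torch.compile(disable=not graph_mode_enabled())
def scale_and_shift(x, scale, shift):
    # Compiled only when graph mode is enabled; otherwise plain eager code.
    return x * (1 + scale) + shift

x = torch.randn(4, 8)
print(scale_and_shift(x, torch.tensor(0.5), torch.tensor(0.1)).shape)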
@@ -108,6 +115,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -136,7 +145,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                     freqs,
                     context,
                 )
+                if audio_dit_blocks:
+                    x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
                 self.weights_stream_mgr.swap_weights()
                 if block_idx == self.blocks_num - 1:
@@ -144,6 +155,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                     self.weights_stream_mgr._async_prefetch_block(weights.blocks)
             if self.clean_cuda_cache:
                 del grid_sizes, embed, embed0, seq_lens, freqs, context
                 torch.cuda.empty_cache()
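The clean_cuda_cache path above drops intermediate references and returns cached blocks to the driver between transformer blocks, which matters when weights are streamed on and off the GPU. A small sketch of the idiom (assumes a CUDA device; falls back to CPU so it still runs):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
tmp = torch.randn(1024, 1024, device=device)
out = tmp.sum()

# Dropping the last reference lets the caching allocator recycle the block;
# empty_cache() then returns unused cached memory to the driver so the next
# offloaded block (or another process) can claim it.
del tmp
if device == "cuda":
    torch.cuda.empty_cache()
print(out.item())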
@@ -178,6 +190,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
             if not is_last_phase:
@@ -238,6 +252,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                 next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -274,6 +290,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

+    @torch._dynamo.disable
+    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
+        for ipa_out in audio_dit_blocks:
+            if block_idx in ipa_out:
+                cur_modify = ipa_out[block_idx]
+                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+        return x
+
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
@@ -286,12 +312,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
-            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
-                for ipa_out in audio_dit_blocks:
-                    if block_idx in ipa_out:
-                        cur_modify = ipa_out[block_idx]
-                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
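Two details in this refactor are worth noting. First, if audio_dit_blocks: is the idiomatic truthiness check covering both None and an empty list. Second, the extracted _apply_audio_dit helper is wrapped in @torch._dynamo.disable, which tells the compiler to run it eagerly rather than trace it; the dict-keyed dispatch through arbitrary modify_func callables would otherwise cause graph breaks or recompilations inside the compiled infer. A minimal sketch of the same exclusion pattern (the blocks registry here is illustrative):

import torch

@torch._dynamo.disable
def apply_hooks(x, block_idx, blocks):
    # Dynamic, data-dependent dispatch: excluded from tracing, runs eagerly.
    for registry in blocks:
        if block_idx in registry:
            entry = registry[block_idx]
            x = entry["modify_func"](x, **entry["kwargs"])
    return x

@torch.compile
def layer(x, block_idx, blocks):
    x = x * 2.0  # this part is compiled as usual
    return apply_hooks(x, block_idx, blocks)  # eager island inside the graph

blocks = [{0: {"modify_func": lambda x, scale: x + scale, "kwargs": {"scale": 1.0}}}]
print(layer(torch.ones(2), 0, blocks))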
@@ -327,13 +350,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
-
     def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
lightx2v/models/runners/wan/wan_audio_runner.py
@@ -302,7 +302,7 @@ class VideoGenerator:
         return mask.transpose(0, 1)

     @torch.no_grad()
-    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+    def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
         """Generate video segment"""
         # Update inputs with audio features
         inputs["audio_encoder_output"] = audio_features
@@ -352,7 +352,8 @@ class VideoGenerator:
             inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

         # Run inference loop
-        total_steps = self.model.scheduler.infer_steps
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
@@ -686,6 +687,62 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -686,6 +687,62 @@ class WanAudioRunner(WanRunner): # type:ignore
ret
[
"target_shape"
]
=
self
.
config
.
target_shape
ret
[
"target_shape"
]
=
self
.
config
.
target_shape
return
ret
return
ret
def
run_step
(
self
):
"""Optimized pipeline with modular components"""
self
.
initialize
()
assert
self
.
_audio_processor
is
not
None
assert
self
.
_audio_preprocess
is
not
None
self
.
_video_generator
=
VideoGenerator
(
self
.
model
,
self
.
vae_encoder
,
self
.
vae_decoder
,
self
.
config
,
self
.
progress_callback
)
with
memory_efficient_inference
():
if
self
.
config
[
"use_prompt_enhancer"
]:
self
.
config
[
"prompt_enhanced"
]
=
self
.
post_prompt_enhancer
()
self
.
inputs
=
self
.
prepare_inputs
()
# Re-initialize scheduler after image encoding sets correct dimensions
self
.
init_scheduler
()
self
.
model
.
scheduler
.
prepare
(
self
.
inputs
[
"image_encoder_output"
])
# Re-create video generator with updated model/scheduler
self
.
_video_generator
=
VideoGenerator
(
self
.
model
,
self
.
vae_encoder
,
self
.
vae_decoder
,
self
.
config
,
self
.
progress_callback
)
# Process audio
audio_array
=
self
.
_audio_processor
.
load_audio
(
self
.
config
[
"audio_path"
])
video_duration
=
self
.
config
.
get
(
"video_duration"
,
5
)
target_fps
=
self
.
config
.
get
(
"target_fps"
,
16
)
max_num_frames
=
self
.
config
.
get
(
"target_video_length"
,
81
)
audio_len
=
int
(
audio_array
.
shape
[
0
]
/
self
.
_audio_processor
.
audio_sr
*
target_fps
)
expected_frames
=
min
(
max
(
1
,
int
(
video_duration
*
target_fps
)),
audio_len
)
# Segment audio
audio_segments
=
self
.
_audio_processor
.
segment_audio
(
audio_array
,
expected_frames
,
max_num_frames
)
self
.
_video_generator
.
total_segments
=
len
(
audio_segments
)
# Generate video segments
prev_video
=
None
torch
.
manual_seed
(
self
.
config
.
seed
)
# Process audio features
audio_features
=
self
.
_audio_preprocess
(
audio_segments
[
0
].
audio_array
,
sampling_rate
=
self
.
_audio_processor
.
audio_sr
,
return_tensors
=
"pt"
).
input_values
.
squeeze
(
0
).
to
(
self
.
model
.
device
)
# Generate video segment
with
memory_efficient_inference
():
self
.
_video_generator
.
generate_segment
(
self
.
inputs
.
copy
(),
# Copy to avoid modifying original
audio_features
,
prev_video
=
prev_video
,
prev_frame_length
=
5
,
segment_idx
=
0
,
total_steps
=
1
)
# Final cleanup
self
.
end_run
()
@
RUNNER_REGISTER
(
"wan2.2_moe_audio"
)
@
RUNNER_REGISTER
(
"wan2.2_moe_audio"
)
class
Wan22MoeAudioRunner
(
WanAudioRunner
):
class
Wan22MoeAudioRunner
(
WanAudioRunner
):
...
...
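The frame-count arithmetic in run_step clamps the requested duration to the frames the audio can actually cover. A worked example with assumed values (5 s of 16 kHz audio at 16 fps):

audio_samples = 80000   # 5 s of 16 kHz audio (assumed)
audio_sr = 16000
target_fps = 16
video_duration = 5

audio_len = int(audio_samples / audio_sr * target_fps)                      # 5 s * 16 fps = 80
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)  # min(80, 80) = 80
print(audio_len, expected_frames)

If the audio were only 3 s long, audio_len would be 48 and expected_frames would shrink to 48, so the generated video never outruns the audio.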