xuwx1 / LightX2V · commit 92f067f1

Fix audio model compile and offload bugs

Authored Aug 07, 2025 by gushiqiao · parent 00962c67
Showing 3 changed files with 7 additions and 10 deletions (+7 −10):
- lightx2v/models/networks/wan/audio_adapter.py (+1 −1)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+4 −7)
- lightx2v/models/runners/wan/wan_audio_runner.py (+2 −2)
lightx2v/models/networks/wan/audio_adapter.py (+1 −1)
```diff
...
@@ -113,7 +113,7 @@ class PerceiverAttentionCA(nn.Module):
         shift_scale_gate = torch.zeros((1, 3, inner_dim))
         shift_scale_gate[:, 2] = 1
         self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

     def forward(self, x, latents, t_emb, q_lens, k_lens):
         """x shape (batchsize, latent_frame, audio_tokens_per_latent,
         model_dim) latents (batchsize, length, model_dim)"""
...
```
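For context, `persistent=False` keeps this constant shift/scale/gate tensor out of `state_dict()` while still letting it follow the module across device moves, which matters once weights are offloaded and reloaded. A minimal sketch of the pattern (a toy module, not the LightX2V class):

```python
import torch
import torch.nn as nn

class GateDemo(nn.Module):
    """Toy module using the same buffer pattern as PerceiverAttentionCA."""

    def __init__(self, inner_dim: int = 8):
        super().__init__()
        shift_scale_gate = torch.zeros((1, 3, inner_dim))
        shift_scale_gate[:, 2] = 1  # initialize the gate channel to 1
        # persistent=False: the buffer moves with .to()/.cuda(), but is
        # excluded from state_dict(), so checkpoints need not carry it.
        self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

m = GateDemo()
assert "shift_scale_gate" not in m.state_dict()
shift, scale, gate = m.shift_scale_gate.unbind(dim=1)  # three (1, inner_dim) tensors
```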
lightx2v/models/networks/wan/infer/transformer_infer.py (+4 −7)
```diff
...
@@ -83,14 +83,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

     def compute_freqs(self, q, grid_sizes, freqs):
         if "audio" in self.config.get("model_cls", ""):
             freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
         else:
             freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         return freqs_i
```
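The `cu_seqlens` construction above turns per-sample sequence lengths into int32 cumulative offsets, the packed-batch format that varlen attention kernels (e.g. flash-attn) expect. A small standalone check:

```python
import torch

q_lens = torch.tensor([3, 5, 2])  # per-sample query lengths
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
# Sample i occupies rows cu_seqlens_q[i] : cu_seqlens_q[i + 1] of the packed batch.
```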
```diff
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
```
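`@torch.compile(disable=...)` evaluates its gate once, at decoration time, so compilation is opted in or out globally before the first call. A runnable sketch of the gating pattern; the environment variable and the body of `CHECK_ENABLE_GRAPH_MODE` here are assumptions, not necessarily LightX2V's actual helper:

```python
import os
import torch

def CHECK_ENABLE_GRAPH_MODE() -> bool:
    # Hypothetical stand-in for LightX2V's helper; the real check may differ.
    return os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"

@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # With disable=True this runs eagerly; otherwise it is compiled on first
    # call. Decorating a thin wrapper (as infer() does above) keeps dynamic
    # Python control flow out of the compiled region.
    return x * 2 + y

print(scale_add(torch.ones(2), torch.ones(2)))  # tensor([3., 3.])
```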
```diff
...
@@ -147,7 +147,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 )
                 if audio_dit_blocks:
                     x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
                 self.weights_stream_mgr.swap_weights()
             if block_idx == self.blocks_num - 1:
```
...
@@ -155,7 +155,6 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -155,7 +155,6 @@ class WanTransformerInfer(BaseTransformerInfer):
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
self
.
weights_stream_mgr
.
_async_prefetch_block
(
weights
.
blocks
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
grid_sizes
,
embed
,
embed0
,
seq_lens
,
freqs
,
context
del
grid_sizes
,
embed
,
embed0
,
seq_lens
,
freqs
,
context
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
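`swap_weights()` and `_async_prefetch_block()` suggest block-wise weight streaming: while block k computes on the default stream, block k+1 is copied host-to-device on a side stream. A minimal CUDA-only sketch of that pattern with a hypothetical manager (not `weights_stream_mgr`'s real API, which a production version would extend with buffer reuse):

```python
import torch

class BlockStreamManager:
    """Hypothetical sketch of prefetch-and-swap weight streaming."""

    def __init__(self, cpu_blocks):
        # Pinned host memory is required for truly asynchronous H2D copies.
        self.cpu_blocks = [b.pin_memory() for b in cpu_blocks]
        self.copy_stream = torch.cuda.Stream()
        self.cur = self.cpu_blocks[0].cuda()
        self.nxt = None

    def async_prefetch_block(self, idx):
        # Copy the next block on a side stream, overlapping with compute.
        if idx < len(self.cpu_blocks):
            with torch.cuda.stream(self.copy_stream):
                self.nxt = self.cpu_blocks[idx].to("cuda", non_blocking=True)

    def swap_weights(self):
        # Don't touch the prefetched weights until the copy has finished.
        torch.cuda.current_stream().wait_stream(self.copy_stream)
        self.cur, self.nxt = self.nxt, None

mgr = BlockStreamManager([torch.randn(1024) for _ in range(3)])
for i in range(3):
    mgr.async_prefetch_block(i + 1)
    _ = mgr.cur * 2  # stand-in for running block i
    if i + 1 < 3:
        mgr.swap_weights()
```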
```diff
...
@@ -295,9 +294,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         for ipa_out in audio_dit_blocks:
             if block_idx in ipa_out:
                 cur_modify = ipa_out[block_idx]
                 x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
         return x
```
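The loop above implies `audio_dit_blocks` is a list of dicts keyed by transformer block index, each value carrying a `modify_func` and its `kwargs`, so the audio adapter can rewrite `x` after selected blocks. A toy, CPU-runnable illustration (the hook function below is invented for the example):

```python
import torch

def add_audio_residual(x, grid_sizes, scale=1.0):
    # Hypothetical hook: inject an audio-conditioned residual into x.
    return x + scale * torch.ones_like(x)

audio_dit_blocks = [{0: {"modify_func": add_audio_residual, "kwargs": {"scale": 0.5}}}]

x, grid_sizes, block_idx = torch.zeros(2, 4), None, 0
for ipa_out in audio_dit_blocks:
    if block_idx in ipa_out:  # only blocks with a registered hook are modified
        cur_modify = ipa_out[block_idx]
        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
print(x)  # every element is now 0.5
```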
```diff
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
...
```
lightx2v/models/runners/wan/wan_audio_runner.py (+2 −2)
```diff
...
@@ -3,7 +3,7 @@ import os
 import subprocess
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
```
```diff
...
@@ -738,7 +738,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             prev_video=prev_video,
             prev_frame_length=5,
             segment_idx=0,
-            total_steps=1
+            total_steps=1,
         )

         # Final cleanup
         self.end_run()
...
```