xuwx1 / LightX2V
Commit 3488b187, authored Aug 25, 2025 by Yang Yong(雍洋); committed by GitHub on Aug 25, 2025
update audio pre_infer (#241)
Parent: d8454a2b
Showing 7 changed files with 28 additions and 86 deletions (+28 −86)
lightx2v/models/networks/wan/audio_model.py                      +0 −1
lightx2v/models/networks/wan/infer/audio/pre_infer.py            +12 −56
lightx2v/models/networks/wan/infer/audio/transformer_infer.py    +1 −1
lightx2v/models/networks/wan/infer/pre_infer.py                  +1 −18
lightx2v/models/runners/wan/wan_audio_runner.py                  +3 −7
lightx2v/models/schedulers/wan/audio/scheduler.py                +5 −3
lightx2v/models/schedulers/wan/scheduler.py                      +6 −0
lightx2v/models/networks/wan/audio_model.py (+0 −1)

@@ -28,7 +28,6 @@ class WanAudioModel(WanModel):
     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter
-        self.pre_infer.set_audio_adapter(self.audio_adapter)
         self.transformer_infer.set_audio_adapter(self.audio_adapter)
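Note: with this change the adapter is forwarded only to the transformer-infer stage; the pre-infer stage no longer keeps a reference, because the per-step time embedding it used to build now comes from the scheduler (see the scheduler diffs below). A minimal, self-contained sketch of the resulting wiring, using hypothetical stand-in classes rather than the repo's real ones:

# Hypothetical stand-ins (not the repo's classes) showing the wiring after this commit:
# the model forwards the adapter to its transformer stage only, while the runner
# hands the same adapter to the scheduler, which precomputes per-step quantities.
class TransformerStage:
    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter


class Model:
    def __init__(self):
        self.transformer_infer = TransformerStage()

    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter
        self.transformer_infer.set_audio_adapter(self.audio_adapter)  # pre_infer is no longer wired


class Scheduler:
    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter


adapter = object()  # placeholder for the real audio adapter
model, scheduler = Model(), Scheduler()
model.set_audio_adapter(adapter)
scheduler.set_audio_adapter(adapter)  # this is what the runner's init_scheduler() now does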
lightx2v/models/networks/wan/infer/audio/pre_infer.py (+12 −56)

-import math
 import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
@@ -35,36 +33,13 @@ class WanAudioPreInfer(WanPreInfer):
         else:
             self.sp_size = 1

-    def set_audio_adapter(self, audio_adapter):
-        self.audio_adapter = audio_adapter
-
     def infer(self, weights, inputs):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
-        if self.config.model_cls == "wan2.2_audio":
-            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-            hidden_states = self.scheduler.latents
-            hidden_states = torch.cat([self.scheduler.latents, prev_mask, prev_latents], dim=0)
-            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-            hidden_states = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * hidden_states
-        else:
-            prev_latents = prev_latents.unsqueeze(0)
-            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-            hidden_states = self.scheduler.latents.unsqueeze(0)
-            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
-            hidden_states = hidden_states.squeeze(0)
         x = hidden_states
-        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
-        if self.config.model_cls == "wan2.2_audio":
-            _, lat_f, lat_h, lat_w = self.scheduler.latents.shape
-            F = (lat_f - 1) * self.config.vae_stride[0] + 1
-            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
-            temp_ts = (prev_mask[0][0][:, ::2, ::2] * t).flatten()
-            temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
-            t = temp_ts.unsqueeze(0)
-        t_emb = self.audio_adapter.time_embedding(t).unflatten(1, (3, -1))
+        t = self.scheduler.timestep_input
         if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
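For context, the `(1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * hidden_states` line removed above is a per-element blend: positions where the mask is 0 keep the previous-clip latents, positions where it is 1 take the scheduler's current latents. A small self-contained sketch of that blend; the tensor shapes are illustrative, not the model's real ones:

import torch

# Illustrative latent blocks of shape (channels, frames, height, width).
latents = torch.randn(16, 21, 60, 104)       # current latents from the scheduler
prev_latents = torch.randn(16, 21, 60, 104)  # latents carried over from the previous clip
prev_mask = torch.zeros(1, 21, 60, 104)      # 0 -> keep prev_latents, 1 -> keep latents
prev_mask[:, 5:] = 1.0                       # mask split chosen arbitrarily for the demo

blended = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * latents

assert torch.equal(blended[:, :5], prev_latents[:, :5])  # masked-out frames are frozen
assert torch.equal(blended[:, 5:], latents[:, 5:])       # the rest follows the sampler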
@@ -76,16 +51,16 @@ class WanAudioPreInfer(WanPreInfer):
         ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
         # batch_size = len(x)
         num_channels, _, height, width = x.shape
-        _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
+        ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
-                (1, num_channels - ref_num_channels, ref_num_frames, height, width),
+                (num_channels - ref_num_channels, ref_num_frames, height, width),
                 dtype=self.scheduler.latents.dtype,
                 device=self.scheduler.latents.device,
             )
-            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
+            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
-        y = list(torch.unbind(ref_image_encoder, dim=0))  # the first (batch) dimension becomes a list
+        y = ref_image_encoder  # the first (batch) dimension becomes a list

         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
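The zero-padding above only makes the reference-image latents channel-compatible with the video latents before both are patchified; nothing is added where the channel counts already match. A self-contained sketch of that step with made-up sizes:

import torch

num_channels, height, width = 48, 60, 104  # channel/spatial sizes of the video latents (illustrative)
ref = torch.randn(16, 1, 60, 104)          # reference latents: (ref_channels, ref_frames, H, W)

ref_num_channels, ref_num_frames = ref.shape[0], ref.shape[1]
if ref_num_channels != num_channels:
    zero_padding = torch.zeros(
        (num_channels - ref_num_channels, ref_num_frames, height, width),
        dtype=ref.dtype,
        device=ref.device,
    )
    ref = torch.concat([ref, zero_padding], dim=0)  # pad along the channel dimension

print(ref.shape)  # torch.Size([48, 1, 60, 104])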
@@ -93,29 +68,10 @@ class WanAudioPreInfer(WanPreInfer):
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
-        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
-        # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
-        y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
-        ref_seq_lens = torch.tensor([u.size(0) for u in y], dtype=torch.long)
-        x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
-        x = torch.stack(x, dim=0)
-        seq_len = x[0].size(0)
-        if self.config.model_cls == "wan2.2_audio":
-            bt = t.size(0)
-            ref_seq_len = ref_seq_lens[0].item()
-            t = torch.cat(
-                [
-                    t,
-                    torch.zeros(
-                        (1, ref_seq_len),
-                        dtype=t.dtype,
-                        device=t.device,
-                    ),
-                ],
-                dim=1,
-            )
+        y = weights.patch_embedding.apply(y.unsqueeze(0))
+        y = y.flatten(2).transpose(1, 2).contiguous()
+        x = torch.cat([x, y], dim=1)
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.sensitive_layer_dtype != self.infer_dtype:
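After this hunk the video latents `x` and the (channel-matched) reference latents `y` go through the same patch embedding, are flattened into token sequences, and the reference tokens are appended along the sequence dimension via `torch.cat([x, y], dim=1)`. A rough, self-contained sketch of that tokenization; the `nn.Conv3d` stands in for the repo's `weights.patch_embedding.apply`, and the patch size and channel counts are assumptions:

import torch
import torch.nn as nn

# Stand-in patch embedding: a Conv3d whose stride equals its kernel is the usual
# way a (1, 2, 2) patchify step is implemented; the real module lives in `weights`.
patch_embedding = nn.Conv3d(16, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))

x = torch.randn(16, 21, 60, 104)  # video latents (C, F, H, W), illustrative shape
y = torch.randn(16, 1, 60, 104)   # reference latents, already channel-matched

x = patch_embedding(x.unsqueeze(0))            # (1, dim, F, H/2, W/2)
x = x.flatten(2).transpose(1, 2).contiguous()  # (1, video_tokens, dim)

y = patch_embedding(y.unsqueeze(0))
y = y.flatten(2).transpose(1, 2).contiguous()  # (1, ref_tokens, dim)

x = torch.cat([x, y], dim=1)                   # reference tokens appended to the sequence
print(x.shape)                                 # (1, video_tokens + ref_tokens, dim)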
@@ -167,5 +123,5 @@ class WanAudioPreInfer(WanPreInfer):
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "t_emb": t_emb},
+            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"]},
         )
lightx2v/models/networks/wan/infer/audio/transformer_infer.py (+1 −1)

@@ -32,7 +32,7 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
             grid_sizes=pre_infer_out.grid_sizes,
             ca_block=self.audio_adapter.ca[self.block_idx],
             audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
-            t_emb=pre_infer_out.adapter_output["t_emb"],
+            t_emb=self.scheduler.audio_adapter_t_emb,
             weight=1.0,
             seq_p_group=self.seq_p_group,
         )
lightx2v/models/networks/wan/infer/pre_infer.py (+1 −18)

@@ -34,15 +34,7 @@ class WanPreInfer:
     def infer(self, weights, inputs, kv_start=0, kv_end=0):
         x = self.scheduler.latents
-        if self.scheduler.flag_df:
-            t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
-            assert t.dim() == 2  # df inference models use a 2-D timestep
-        else:
-            timestep = self.scheduler.timesteps[self.scheduler.step_index]
-            t = torch.stack([timestep])
-            if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
-                t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()
+        t = self.scheduler.timestep_input
         if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]

@@ -91,15 +83,6 @@ class WanPreInfer:
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
-        if self.scheduler.flag_df:
-            b, f = t.shape
-            assert b == len(x)  # batch_size == 1
-            embed = embed.view(b, f, 1, 1, self.dim)
-            embed0 = embed0.view(b, f, 1, 1, 6, self.dim)
-            embed = embed.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1).flatten(1, 3)
-            embed0 = embed0.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1, 1).flatten(1, 3)
-            embed0 = embed0.transpose(1, 2).contiguous()

         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
         if self.sensitive_layer_dtype != self.infer_dtype:
lightx2v/models/runners/wan/wan_audio_runner.py (+3 −7)

@@ -246,6 +246,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def init_scheduler(self):
         """Initialize consistency model scheduler"""
         scheduler = ConsistencyModelScheduler(self.config)
+        scheduler.set_audio_adapter(self.audio_adapter)
         self.model.set_scheduler(scheduler)

     def read_audio_input(self):

@@ -292,12 +293,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def run_vae_encoder(self, img):
         img = rearrange(img, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))
-        if self.config.model_cls == "wan2.2_audio":
-            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
-        else:
-            if isinstance(vae_encoder_out, list):
-                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
+        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))[0]
         return vae_encoder_out

     @ProfilingContext("Run Encoders")

@@ -351,7 +347,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         frames_n = (nframe - 1) * 4 + 1
         prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
         prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+        prev_mask = self._wan_mask_rearrange(prev_mask)
         if prev_latents.shape[-2:] != (height, width):
             logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
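In the last hunk the mask is built over pixel-space frames: it starts as all ones, frames from `prev_frame_len` onward are zeroed, and `_wan_mask_rearrange` then folds it onto the latent frame grid (the extra `unsqueeze(0)` is what this commit drops). A small self-contained sketch of that construction; the downsampling at the end is only a stand-in for the repo-specific rearrange helper:

import torch

nframe, height, width = 21, 60, 104  # latent frame count and latent spatial size (illustrative)
prev_frame_len = 17                  # number of pixel-space frames flagged with 1 (assumption)
device, dtype = "cpu", torch.float32

frames_n = (nframe - 1) * 4 + 1      # latent frames -> pixel-space frames (temporal VAE stride 4)
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0    # everything from prev_frame_len onward is cleared

# Stand-in for self._wan_mask_rearrange (repo-specific): keep one mask value per
# latent frame by sampling every 4th pixel-space frame.
latent_mask = prev_mask[:, ::4]
print(prev_mask.shape, latent_mask.shape)  # (1, 81, 60, 104) (1, 21, 60, 104)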
lightx2v/models/schedulers/wan/audio/scheduler.py (+5 −3)

@@ -12,10 +12,12 @@ class ConsistencyModelScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)

+    def set_audio_adapter(self, audio_adapter):
+        self.audio_adapter = audio_adapter
+
     def step_pre(self, step_index):
-        self.step_index = step_index
-        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
-            self.latents = self.latents.to(GET_DTYPE())
+        super().step_pre(step_index)
+        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))

     def prepare(self, image_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
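The net effect here is that the adapter's per-step time embedding is now computed once in `step_pre`, right after the base class has prepared `timestep_input`, and cached as `audio_adapter_t_emb` for the transformer blocks to read (see the `transformer_infer.py` change above). A self-contained sketch of that caching pattern, with a toy sinusoidal `time_embedding` standing in for the real adapter module:

import torch

class ToyAudioAdapter:
    """Stand-in for the real audio adapter: maps timesteps to a (B, 3, D) embedding."""
    def __init__(self, dim=12):
        self.dim = dim  # must be divisible by 3 for the unflatten below

    def time_embedding(self, t):
        # toy sinusoidal features of shape (B, dim)
        freqs = torch.arange(self.dim // 2, dtype=torch.float32)
        ang = t[:, None].float() * torch.exp(-freqs)[None, :]
        return torch.cat([ang.sin(), ang.cos()], dim=-1)


class ToyScheduler:
    def __init__(self, timesteps):
        self.timesteps = timesteps

    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def step_pre(self, step_index):
        self.step_index = step_index
        self.timestep_input = torch.stack([self.timesteps[step_index]])
        # cache the adapter embedding once per step; blocks read it from the scheduler
        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))


scheduler = ToyScheduler(torch.linspace(1000, 0, 50))
scheduler.set_audio_adapter(ToyAudioAdapter())
scheduler.step_pre(0)
print(scheduler.audio_adapter_t_emb.shape)  # torch.Size([1, 3, 4])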
lightx2v/models/schedulers/wan/scheduler.py (+6 −0)

@@ -320,6 +320,12 @@ class WanScheduler(BaseScheduler):
         x_t = x_t.to(x.dtype)
         return x_t

+    def step_pre(self, step_index):
+        super().step_pre(step_index)
+        self.timestep_input = torch.stack([self.timesteps[self.step_index]])
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+            self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
+
     def step_post(self):
         model_output = self.noise_pred.to(torch.float32)
         timestep = self.timesteps[self.step_index]
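This is the core of the refactor: timestep preparation moves out of the pre-infer stage and into `step_pre`, so every consumer reads the same `timestep_input`. For the `wan2.2` i2v case, the scalar timestep is expanded into a per-token value by multiplying with a spatially-downsampled mask and flattening, just as the code removed from `pre_infer.py` used to do. A self-contained sketch of that expansion (the mask shape and sizes are illustrative):

import torch

timesteps = torch.linspace(1000, 0, 50)
step_index = 0

# Scalar timestep for this denoising step, shape (1,)
timestep_input = torch.stack([timesteps[step_index]])

# Illustrative i2v mask: repeat channels x latent frames x H x W; here the first
# latent frame is treated as conditioning (mask 0) and the rest as generated (mask 1).
mask = torch.ones(4, 21, 60, 104)
mask[:, 0] = 0

# Per-token timestep: masked-out positions get t=0, the others get the full t.
# The ::2 strides keep one value per 2x2 spatial patch (matching the patch size
# the diff divides by elsewhere).
timestep_input = (mask[0][:, ::2, ::2] * timestep_input).flatten()
print(timestep_input.shape)  # torch.Size([32760]) == 21 * 30 * 52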