xuwx1 / LightX2V · Commit b5bcbed7
Authored Aug 14, 2025 by wangshankun · Parent: 99a6f046

Refactor audio's prepare_prev_latents (重构audio的prepare_prev_latents)
Showing 11 changed files with 130 additions and 127 deletions.
  configs/audio_driven/wan_i2v_audio_offload.json                  +1   -1
  lightx2v/infer.py                                                +15  -2
  lightx2v/models/networks/wan/audio_adapter.py                    +21  -9
  lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py  +23  -7
  lightx2v/models/networks/wan/infer/utils.py                      +0   -32
  lightx2v/models/runners/wan/wan_audio_runner.py                  +60  -66
  lightx2v/models/video_encoders/hf/wan/vae_2_2.py                 +0   -1
  scripts/dist_infer/run_wan22_ti2v_i2v_ulysses.sh                 +2   -2
  scripts/wan/run_wan_i2v_audio.sh                                 +3   -2
  scripts/wan/run_wan_i2v_audio_dist.sh                            +2   -2
  scripts/wan22/run_wan22_ti2v_i2v.sh                              +3   -3
configs/audio_driven/wan_i2v_audio_offload.json

 {
     "infer_steps": 4,
     "target_fps": 16,
-    "video_duration": 5,
+    "video_duration": 12,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 720,
 ...
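The only functional change here raises video_duration from 5 to 12 seconds. Assuming the runner splits long clips into segments of target_video_length frames at target_fps (the key names come from this config; the segmentation code itself is not part of this diff), the effect is roughly:

    # Hypothetical sketch of how the config keys interact; the real
    # segmentation logic lives in the runner and may differ in detail.
    import math

    target_fps = 16
    video_duration = 12        # seconds (previously 5)
    target_video_length = 81   # frames produced per segment

    total_frames = target_fps * video_duration                 # 192 frames
    segments = math.ceil(total_frames / target_video_length)   # 3 segments (was 80 frames -> 1)
    print(total_frames, segments)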
lightx2v/infer.py

@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401

@@ -39,7 +39,20 @@ def main():
         "--model_cls",
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
+        choices=[
+            "wan2.1",
+            "hunyuan",
+            "wan2.1_distill",
+            "wan2.1_causvid",
+            "wan2.1_skyreels_v2_df",
+            "cogvideox",
+            "wan2.1_audio",
+            "wan2.2_moe",
+            "wan2.2_moe_audio",
+            "wan2.2_audio",
+            "wan2.2",
+            "wan2.2_moe_distill",
+        ],
         default="wan2.1",
     )
...
lightx2v/models/networks/wan/audio_adapter.py

@@ -145,7 +145,12 @@ class PerceiverAttentionCA(nn.Module):
         batchsize = len(x)
         x = self.norm_kv(x)
         shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
-        latents = self.norm_q(latents) * (1 + scale) + shift
+        norm_q = self.norm_q(latents)
+        if scale.shape[0] != norm_q.shape[0]:
+            scale = scale.transpose(0, 1)  # (1, 5070, 3072)
+            shift = shift.transpose(0, 1)
+            gate = gate.transpose(0, 1)
+        latents = norm_q * (1 + scale) + shift
         q = self.to_q(latents.to(GET_DTYPE()))
         k, v = self.to_kv(x).chunk(2, dim=-1)
         q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
...
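The new guard covers the case where the modulation tensors arrive token-first while the normalized latents are batch-first; the inline (1, 5070, 3072) comment suggests a 5070-token sequence at hidden size 3072. A standalone sketch of the same fix, with invented shapes. Without the transpose, broadcasting would silently expand to (5070, 5070, 3072) rather than fail:

    # Standalone illustration of the shape guard above; shapes are invented.
    import torch

    norm_q = torch.randn(1, 5070, 3072)   # (batch, tokens, channels)
    scale = torch.randn(5070, 1, 3072)    # modulation arrived token-first

    if scale.shape[0] != norm_q.shape[0]:
        scale = scale.transpose(0, 1)     # -> (1, 5070, 3072)

    latents = norm_q * (1 + scale)        # now a plain element-wise broadcast
    print(latents.shape)                  # torch.Size([1, 5070, 3072])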
@@ -222,16 +227,23 @@ class TimeEmbedding(nn.Module):
         self.act_fn = nn.SiLU()
         self.time_proj = nn.Linear(dim, time_proj_dim)

-    def forward(self, timestep: torch.Tensor):
-        if timestep.dim() == 2:
-            timestep = self.timesteps_proj(timestep.squeeze(0)).unsqueeze(0)
-        else:
-            timestep = self.timesteps_proj(timestep)
-
-        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-        timestep = timestep.to(time_embedder_dtype)
+    def forward(
+        self,
+        timestep: torch.Tensor,
+    ):
+        # Project timestep
+        timestep = self.timesteps_proj(timestep)
+
+        # Match dtype with time_embedder (except int8)
+        target_dtype = next(self.time_embedder.parameters()).dtype
+        if timestep.dtype != target_dtype and target_dtype != torch.int8:
+            timestep = timestep.to(target_dtype)
+
+        # Time embedding projection
         temb = self.time_embedder(timestep)
         timestep_proj = self.time_proj(self.act_fn(temb))
-        return timestep_proj
+
+        return timestep_proj.squeeze(0) if timestep_proj.dim() == 3 else timestep_proj


 class AudioAdapter(nn.Module):
...
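The rewritten forward moves rank handling from input to output: a 2-D batched timestep now flows straight through timesteps_proj, the extra dimension is squeezed off the final projection instead, and the dtype cast is skipped when the embedder's weights are int8 (quantized weights must not receive a float cast). A shape-only sketch of the new output contract, with invented dimensions:

    # Shape-only sketch of the new forward contract; dims are invented.
    import torch

    proj = torch.randn(1, 75600, 256)           # a 3-D projection from a batched timestep
    out = proj.squeeze(0) if proj.dim() == 3 else proj
    print(out.shape)                            # torch.Size([75600, 256])

    # int8 guard: the cast is skipped when the embedder is quantized
    timestep = torch.randn(1, 75600)
    target_dtype = torch.int8
    if timestep.dtype != target_dtype and target_dtype != torch.int8:
        timestep = timestep.to(target_dtype)    # not reached here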
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

+import math
+
 import torch

 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *

 from ..module_io import WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
+from ..utils import rope_params, sinusoidal_embedding_1d
+from loguru import logger


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
@@ -28,13 +30,17 @@ class WanAudioPreInfer(WanPreInfer):
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+        if config.parallel:
+            self.sp_size = config.parallel.get("seq_p_size", 1)
+        else:
+            self.sp_size = 1

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
-            mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
-            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
+            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+            hidden_states = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * hidden_states
         else:
             prev_latents = prev_latents.unsqueeze(0)
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
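The key change: instead of regenerating a fresh mask with masks_like on every call, the pre-infer step now reuses the prev_mask that prepare_prev_latents stored in previmg_encoder_output. The blend itself is a standard masked mix: 0 in the mask keeps the previous-segment latents, 1 keeps the scheduler's noisy latents. A toy sketch with invented shapes:

    # Toy illustration of the masked blend; shapes are invented.
    import torch

    noisy = torch.randn(16, 21, 90, 160)    # scheduler latents (C, F, H, W)
    prev = torch.randn(16, 21, 90, 160)     # latents encoded from the previous segment
    mask = torch.ones(1, 16, 21, 90, 160)
    mask[:, :, :6] = 0                      # first 6 latent frames come from prev

    blended = (1.0 - mask[0]) * prev + mask[0] * noisy
    assert torch.equal(blended[:, :6], prev[:, :6])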
@@ -45,6 +51,16 @@ class WanAudioPreInfer(WanPreInfer):
         x = [hidden_states]
         t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+        if self.config.model_cls == "wan2.2_audio":
+            _, lat_f, lat_h, lat_w = self.scheduler.latents.shape
+            F = (lat_f - 1) * self.config.vae_stride[0] + 1
+            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
+            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+            temp_ts = (prev_mask[0][0][:, ::2, ::2] * t).flatten()
+            temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
+            t = temp_ts.unsqueeze(0)
+
         audio_dit_blocks = []
         audio_encoder_output = inputs["audio_encoder_output"]
         audio_model_input = {
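This block builds a per-token timestep vector for wan2.2_audio: tokens covered by the previous-segment mask get a mask-scaled timestep, and the tail is padded up to max_seq_len, rounded up to a multiple of the sequence-parallel world size sp_size. Working the arithmetic for a hypothetical 720x1280, 81-frame setup with vae_stride[0]=4 and 2x2 patches (these values are assumptions, not read from the config in this diff):

    # Hypothetical numbers; vae_stride, patch_size and latent dims are assumed.
    import math

    lat_f, lat_h, lat_w = 21, 90, 160      # latents for 81 frames at 720x1280
    vae_stride_t, patch_h, patch_w = 4, 2, 2
    sp_size = 4                            # sequence-parallel world size

    F = (lat_f - 1) * vae_stride_t + 1                                    # 81
    max_seq_len = ((F - 1) // vae_stride_t + 1) * lat_h * lat_w // (patch_h * patch_w)
    print(max_seq_len)                     # 21 * 90 * 160 // 4 = 75600
    max_seq_len = math.ceil(max_seq_len / sp_size) * sp_size
    print(max_seq_len)                     # 75600, already a multiple of 4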
@@ -53,7 +69,7 @@ class WanAudioPreInfer(WanPreInfer):
             "timestep": t,
         }
         audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
-        audio_dit_blocks = None  ##Debug Drop Audio
+        # audio_dit_blocks = None  ##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
@@ -66,7 +82,7 @@ class WanAudioPreInfer(WanPreInfer):
         batch_size = len(x)
         num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
...
lightx2v/models/networks/wan/infer/utils.py

@@ -4,38 +4,6 @@ import torch.distributed as dist
 from lightx2v.utils.envs import *


-def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
-    assert isinstance(tensor, list)
-    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
-    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
-    if prev_length == 0:
-        return out1, out2
-    if zero:
-        if generator is not None:
-            for u, v in zip(out1, out2):
-                random_num = torch.rand(1, generator=generator, device=generator.device).item()
-                if random_num < p:
-                    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
-                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
-                else:
-                    u[:, :prev_length] = u[:, :prev_length]
-                    v[:, :prev_length] = v[:, :prev_length]
-        else:
-            for u, v in zip(out1, out2):
-                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
-                v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
-    return out1, out2
-
-
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
...
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
+from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image
+
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE


 @contextmanager
 def memory_efficient_inference():
...
@@ -257,9 +258,6 @@ class VideoGenerator:
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        if prev_video is None:
-            return None
-
         device = torch.device("cuda")
         dtype = GET_DTYPE()
         vae_dtype = torch.float
@@ -267,22 +265,29 @@ class VideoGenerator:
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)

-        # Extract and process last frames
-        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
-        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
-        prev_frames[:, :, :prev_frame_length] = last_frames
-
-        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-
-        # Create mask
-        prev_token_length = (prev_frame_length - 1) // 4 + 1
-        _, nframe, height, width = self.model.scheduler.latents.shape
-        frames_n = (nframe - 1) * 4 + 1
-        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
-
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-        prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+        if prev_video is not None:
+            # Extract and process last frames
+            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+            last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
+            prev_frames[:, :, :prev_frame_length] = last_frames
+
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        if self.config.model_cls == "wan2.2_audio":
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+            _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
+        else:
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            if prev_video is not None:
+                prev_token_length = (prev_frame_length - 1) // 4 + 1
+                prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
+            else:
+                prev_frame_len = 0
+            frames_n = (nframe - 1) * 4 + 1
+            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+            prev_mask[:, prev_frame_len:] = 0
+            prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

         if prev_latents.shape[-2:] != (height, width):
             logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
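The wan2.1 branch derives the pixel-frame extent of the conditioning region from the latent token count: prev_frame_length pixel frames compress to prev_token_length latent frames under the 4x temporal VAE stride, then map back to pixel frames before the mask is built and rearranged. With an assumed prev_frame_length of 5 (illustrative value only), the round trip looks like:

    # Assumed prev_frame_length; the pairing of (x - 1) // 4 + 1 with
    # (x - 1) * 4 + 1 matches the compression/expansion used in the diff.
    prev_frame_length = 5
    prev_token_length = (prev_frame_length - 1) // 4 + 1       # 2 latent frames
    prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)   # 5 pixel frames
    print(prev_token_length, prev_frame_len)                   # 2 5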
@@ -290,6 +295,31 @@ class VideoGenerator:
         return {"prev_latents": prev_latents, "prev_mask": prev_mask}

+    def _wan22_masks_like(self, tensor, zero=False, generator=None, p=0.2, prev_length=1):
+        assert isinstance(tensor, list)
+        out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+        out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+        if prev_length == 0:
+            return out1, out2
+        if zero:
+            if generator is not None:
+                for u, v in zip(out1, out2):
+                    random_num = torch.rand(1, generator=generator, device=generator.device).item()
+                    if random_num < p:
+                        u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
+                        v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+                    else:
+                        u[:, :prev_length] = u[:, :prev_length]
+                        v[:, :prev_length] = v[:, :prev_length]
+            else:
+                for u, v in zip(out1, out2):
+                    u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
+                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+        return out1, out2
+
     def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
         """Rearrange mask for WAN model"""
         if mask.ndim == 3:
...
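_wan22_masks_like is the masks_like helper deleted from infer/utils.py, moved onto the runner. In the path used above it is called with zero=True and no generator, so the random branch never runs and the behavior reduces to: both returned masks are zero over the first prev_length latent frames and one elsewhere. A quick check of that reduced behavior, with an illustrative prev_length:

    # Minimal check of the zero=True, generator=None path actually used above.
    import torch

    latents = torch.randn(16, 21, 90, 160)
    prev_length = 2                                   # illustrative only

    out2 = [torch.ones_like(latents)]
    for v in out2:
        v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])

    prev_mask = out2[0]
    print(prev_mask[:, :prev_length].sum().item())    # 0.0 -> keep prev latents
    print(prev_mask[:, prev_length:].unique())        # tensor([1.]) -> keep noise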
@@ -312,52 +342,7 @@ class VideoGenerator:
         if segment_idx > 0:
             self.model.scheduler.reset()

-        device = torch.device("cuda")
-        dtype = GET_DTYPE()
-        vae_dtype = torch.float
-        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
-        max_num_frames = self.config.target_video_length
-
-        if segment_idx == 0:
-            # First segment - create zero frames
-            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-            if self.config.model_cls == 'wan2.2_audio':
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
-            else:
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-            prev_len = 0
-        else:
-            # Subsequent segments - use previous video
-            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
-            if previmg_encoder_output:
-                prev_latents = previmg_encoder_output["prev_latents"]
-                prev_len = (prev_frame_length - 1) // 4 + 1
-            else:
-                # Fallback to zeros if prepare_prev_latents fails
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                if self.config.model_cls == 'wan2.2_audio':
-                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
-                else:
-                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = 0
-
-        # Create mask for prev_latents
-        _, nframe, height, width = self.model.scheduler.latents.shape
-        frames_n = (nframe - 1) * 4 + 1
-        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
-
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-        prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
-
-        if prev_latents.shape[-2:] != (height, width):
-            logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
-            prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
-
-        # Always set previmg_encoder_output
-        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
+        # Prepare previous latents - ALWAYS needed, even for first segment
+        inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)

         # Run inference loop
         if total_steps is None:
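This is the hunk that gives the commit its name: the per-segment branching (zero-frame bootstrap for segment 0, plus a fallback re-encode when prepare_prev_latents returned None) collapses into a single unconditional call, because prepare_prev_latents now handles prev_video=None itself by encoding zero frames. A runnable schematic of the control-flow change; the helpers are stand-ins, not real APIs:

    # Schematic of the control-flow change; helpers are stand-ins, not real APIs.
    def encode_zero_frames():
        return {"prev_latents": "zeros", "prev_mask": "ones-then-zero"}

    def prepare_prev_latents(prev_video):
        # New behavior: covers prev_video=None (first segment) internally.
        return encode_zero_frames() if prev_video is None else {"prev_latents": prev_video}

    # Before: the segment loop branched on segment_idx and re-implemented the
    # zero-frame path inline. After: one unconditional call per segment.
    for segment_idx, prev_video in enumerate([None, "segment-0-tail"]):
        print(segment_idx, prepare_prev_latents(prev_video))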
@@ -373,6 +358,10 @@ class VideoGenerator:
             with ProfilingContext4Debug("step_post"):
                 self.model.scheduler.step_post()

+            if self.config.model_cls == "wan2.2_audio":
+                prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+                prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
+                self.model.scheduler.latents = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * self.model.scheduler.latents
+
             if self.progress_callback:
                 segment_progress = (segment_idx * total_steps + step_index + 1) / (self.total_segments * total_steps)
@@ -396,6 +385,11 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self._video_generator = None
         self._audio_preprocess = None

+        if self.seq_p_group is None:
+            self.sp_size = 1
+        else:
+            self.sp_size = dist.get_world_size(self.seq_p_group)
+
     def initialize(self):
         """Initialize all models once for multiple runs"""

@@ -620,7 +614,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
-        logger.info(f"Loaded base model: {self.config.model_path}")
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)

@@ -695,7 +688,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         num_channels_latents = 16
         if self.config.model_cls == "wan2.2_audio":
             num_channels_latents = self.config.num_channels_latents
         if self.config.task == "i2v":
             self.config.target_shape = (
                 num_channels_latents,
...

@@ -813,6 +806,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_decoder = self.load_vae_decoder()
         return vae_encoder, vae_decoder
+

 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
     def __init__(self, config):
...
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange

 from lightx2v.utils.utils import load_weights
-from loguru import logger

 __all__ = [
     "Wan2_2_VAE",
...
scripts/dist_infer/run_wan22_ti2v_i2v_ulysses.sh

 #!/bin/bash
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/gushiqiao/models/Wan2.2-R2V812-Audio-5B"
 export CUDA_VISIBLE_DEVICES=0,1,2,3
...
scripts/wan/run_wan_i2v_audio.sh

 #!/bin/bash
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P"
 export CUDA_VISIBLE_DEVICES=0

@@ -11,6 +11,7 @@ source ${lightx2v_path}/scripts/base/base.sh
 export TORCH_CUDA_ARCH_LIST="9.0"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 export ENABLE_GRAPH_MODE=false
+export ENABLE_GRAPH_MODE=false
 export SENSITIVE_LAYER_DTYPE=None
...
scripts/wan/run_wan_i2v_audio_dist.sh

@@ -2,8 +2,8 @@
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P"

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
...
scripts/wan22/run_wan22_ti2v_i2v.sh

 #!/bin/bash
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/gushiqiao/models/official_models/wan2.2/Wan2.2-TI2V-5B"
 export CUDA_VISIBLE_DEVICES=0

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
+export ENABLE_GRAPH_MODE=false

 python -m lightx2v.infer \
     --model_cls wan2.2 \
     --task i2v \
...