xuwx1 / LightX2V · Commits · 8bdefedf

Commit 8bdefedf
Authored Aug 14, 2025 by wangshankun
Parent: 7516ad2a

add ti2v audio
Showing 8 changed files with 173 additions and 23 deletions (+173 -23):
configs/audio_driven/wan22_ti2v_i2v_audio.json                     +29   -0
lightx2v/infer.py                                                   +2   -2
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py    +17  -10
lightx2v/models/networks/wan/infer/utils.py                        +32   -0
lightx2v/models/networks/wan/model.py                               +0   -3
lightx2v/models/runners/wan/wan_audio_runner.py                    +65   -7
lightx2v/models/video_encoders/hf/wan/vae_2_2.py                    +6   -1
scripts/wan22/run_wan22_ti2v_i2v_audio.sh                          +22   -0
configs/audio_driven/wan22_ti2v_i2v_audio.json  (new file, mode 100755)

{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 12,
    "audio_sr": 16000,
    "target_video_length": 121,
    "text_len": 512,
    "target_height": 704,
    "target_width": 1280,
    "num_channels_latents": 48,
    "vae_stride": [4, 16, 16],
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": false,
    "offload_granularity": "model",
    "fps": 24,
    "use_image_encoder": false,
    "lora_configs": [
        {
            "path": "/data/nvme0/models/wan_ti2v_5b_ref/20250812/model_ema.safetensors",
            "strength": 0.125
        }
    ]
}
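The config above fixes both the pixel-space target (704x1280, 121 frames) and the latent channel count (48). A minimal sketch, not part of the commit, of the latent-space shape this implies, assuming the usual Wan-style convention of (frames - 1) // temporal_stride + 1 latent frames and height/width divided by the spatial stride:

import json

# Derive the latent-space shape implied by the config above (sketch only).
with open("configs/audio_driven/wan22_ti2v_i2v_audio.json") as f:
    cfg = json.load(f)

t_stride, h_stride, w_stride = cfg["vae_stride"]                     # [4, 16, 16]
latent_frames = (cfg["target_video_length"] - 1) // t_stride + 1    # (121 - 1) // 4 + 1 = 31
latent_h = cfg["target_height"] // h_stride                          # 704 // 16 = 44
latent_w = cfg["target_width"] // w_stride                           # 1280 // 16 = 80

print(cfg["num_channels_latents"], latent_frames, latent_h, latent_w)  # 48 31 44 80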
lightx2v/infer.py

@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401

@@ -39,7 +39,7 @@ def main():
         "--model_cls",
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2", "wan2.2_moe_distill"],
+        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
         default="wan2.1",
     )
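The new "wan2.2_audio" choice only becomes usable because importing wan_audio_runner (the import is kept alive by "# noqa: F401") runs its @RUNNER_REGISTER("wan2.2_audio") decorator, which maps the CLI string to a runner class. Below is a minimal sketch of that registry pattern; the real RUNNER_REGISTER in lightx2v.utils.registry_factory may differ in detail:

# Illustrative registry sketch only -- not the library's actual implementation.
class Register:
    def __init__(self):
        self._classes = {}

    def __call__(self, name):
        def decorator(cls):
            self._classes[name] = cls   # record the class under its CLI name
            return cls
        return decorator

    def get(self, name):
        return self._classes[name]


RUNNER_REGISTER = Register()


@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner:  # stand-in for the real runner class
    pass


runner_cls = RUNNER_REGISTER.get("wan2.2_audio")  # -> Wan22AudioRunner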
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -3,8 +3,8 @@ import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
-from ..utils import rope_params, sinusoidal_embedding_1d
+from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
+from loguru import logger


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):

@@ -28,12 +28,18 @@ class WanAudioPreInfer(WanPreInfer):
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def infer(self, weights, inputs, positive):
-        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
-        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-        hidden_states = self.scheduler.latents.unsqueeze(0)
-        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
-        hidden_states = hidden_states.squeeze(0)
+        prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
+        if self.config.model_cls == "wan2.2_audio":
+            hidden_states = self.scheduler.latents
+            mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
+            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
+        else:
+            prev_latents = prev_latents.unsqueeze(0)
+            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+            hidden_states = self.scheduler.latents.unsqueeze(0)
+            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
+            hidden_states = hidden_states.squeeze(0)
         x = [hidden_states]
         t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

@@ -46,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
                 "timestep": t,
             }
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
-        ## audio_dit_blocks = None##Debug Drop Audio
+        audio_dit_blocks = None  ##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]

@@ -55,11 +61,11 @@ class WanAudioPreInfer(WanPreInfer):
         seq_len = self.scheduler.seq_len
         clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
-        ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]
+        ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
         batch_size = len(x)
         num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),

@@ -77,6 +83,7 @@ class WanAudioPreInfer(WanPreInfer):
         assert seq_lens.max() <= seq_len
         x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
         valid_patch_length = x[0].size(0)
         y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
         # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
         y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
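For intuition about the new wan2.2_audio branch, here is a toy sketch (not from the repo) of the (1 - mask) * prev_latents + mask * latents blend: positions where the mask is zero keep prev_latents, the rest keep the freshly sampled latents. A short prefix is used for clarity; the committed call passes prev_length=hidden_states.shape[1].

import torch

# Toy illustration of the conditioning blend used above.
latents = torch.randn(4, 16)
prev_latents = torch.randn(4, 16)
prev_length = 4                      # length of the conditioning prefix along dim 1

mask = torch.ones_like(latents)
mask[:, :prev_length] = 0.0          # zero over the prefix, as masks_like produces

blended = (1.0 - mask) * prev_latents + mask * latents
assert torch.equal(blended[:, :prev_length], prev_latents[:, :prev_length])
assert torch.equal(blended[:, prev_length:], latents[:, prev_length:])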
lightx2v/models/networks/wan/infer/utils.py

@@ -3,6 +3,38 @@ import torch
 from lightx2v.utils.envs import *


+def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
+    assert isinstance(tensor, list)
+    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+
+    if prev_length == 0:
+        return out1, out2
+
+    if zero:
+        if generator is not None:
+            for u, v in zip(out1, out2):
+                random_num = torch.rand(1, generator=generator, device=generator.device).item()
+                if random_num < p:
+                    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
+                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+                else:
+                    u[:, :prev_length] = u[:, :prev_length]
+                    v[:, :prev_length] = v[:, :prev_length]
+        else:
+            for u, v in zip(out1, out2):
+                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
+                v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+
+    return out1, out2
+
+
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
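A small usage sketch, not part of the commit, for the masks_like helper added above: with zero=True and no generator it zeroes the first prev_length positions along dim 1 of both returned masks and leaves the rest at one.

import torch

from lightx2v.models.networks.wan.infer.utils import masks_like

latents = torch.ones(16, 8, 4, 4)                        # e.g. (C, T, H, W)
mask1, mask2 = masks_like([latents], zero=True, prev_length=2)

print(mask2[0][:, :2].unique())   # tensor([0.]) -- conditioning prefix
print(mask2[0][:, 2:].unique())   # tensor([1.]) -- rest of the frames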
lightx2v/models/networks/wan/model.py

@@ -231,9 +231,6 @@ class WanModel:
         self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)

-        del self.original_weight_dict
-        torch.cuda.empty_cache()
-
     def _load_weights_distribute(self, weight_dict, is_weight_loader):
         global_src_rank = 0
         target_device = "cpu" if self.cpu_offload else "cuda"
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -24,8 +24,8 @@ from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelSched
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE


 @contextmanager
 def memory_efficient_inference():

@@ -322,7 +322,11 @@ class VideoGenerator:
         if segment_idx == 0:
             # First segment - create zero frames
             prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            if self.config.model_cls == 'wan2.2_audio':
+                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+            else:
+                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
             prev_len = 0
         else:
             # Subsequent segments - use previous video

@@ -333,7 +337,10 @@ class VideoGenerator:
             else:
                 # Fallback to zeros if prepare_prev_latents fails
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+                if self.config.model_cls == 'wan2.2_audio':
+                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+                else:
+                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = 0

         # Create mask for prev_latents

@@ -613,7 +620,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
         logger.info(f"Loaded base model: {self.config.model_path}")
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)

@@ -673,8 +680,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
-        if isinstance(vae_encoder_out, list):
-            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
+        if self.config.model_cls == "wan2.2_audio":
+            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
+        else:
+            if isinstance(vae_encoder_out, list):
+                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
         return vae_encoder_out, clip_encoder_out

@@ -682,6 +693,9 @@ class WanAudioRunner(WanRunner):  # type:ignore
         """Set target shape for generation"""
         ret = {}
         num_channels_latents = 16
+        if self.config.model_cls == "wan2.2_audio":
+            num_channels_latents = self.config.num_channels_latents
+
         if self.config.task == "i2v":
             self.config.target_shape = (
                 num_channels_latents,

@@ -755,6 +769,50 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.end_run()


+@RUNNER_REGISTER("wan2.2_audio")
+class Wan22AudioRunner(WanAudioRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
+        vae_config = {
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
+        }
+        vae_decoder = Wan2_2_VAE(**vae_config)
+        return vae_decoder
+
+    def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
+        vae_config = {
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
+        }
+        if self.config.task != "i2v":
+            return None
+        else:
+            return Wan2_2_VAE(**vae_config)
+
+    def load_vae(self):
+        vae_encoder = self.load_vae_encoder()
+        vae_decoder = self.load_vae_decoder()
+        return vae_encoder, vae_decoder
+
+
 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
     def __init__(self, config):
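One detail worth noting in the new runner: VAE placement falls back from "vae_cpu_offload" to "cpu_offload" when the former is not set. An illustrative snippet, not from the repo, showing that lookup with the committed config (cpu_offload: false), which puts both the encoder and decoder on CUDA:

# Illustration of the offload fallback used in load_vae_decoder/load_vae_encoder above.
config = {"cpu_offload": False}                # no "vae_cpu_offload" key in the config
vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload"))
device = "cpu" if vae_offload else "cuda"
print(vae_offload, device)                     # False cuda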
lightx2v/models/video_encoders/hf/wan/vae_2_2.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from lightx2v.utils.utils import load_weights
+from loguru import logger

 __all__ = [
     "Wan2_2_VAE",

@@ -256,6 +257,10 @@ class AttentionBlock(nn.Module):
 def patchify(x, patch_size):
     if patch_size == 1:
         return x
+
+    if x.dim() == 6:
+        x = x.squeeze(0)
+
     if x.dim() == 4:
         x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
     elif x.dim() == 5:

@@ -828,7 +833,7 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
     weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
-    model.load_state_dict(weights_dict)
+    model.load_state_dict(weights_dict, assign=True)

     return model
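For context on the load_state_dict change above: passing assign=True (available since PyTorch 2.1) swaps the module's parameters for the checkpoint tensors instead of copying values into pre-allocated parameters, so dtype and device come from the checkpoint. A minimal sketch, not from the repo:

import torch
import torch.nn as nn

# With assign=True the module adopts the checkpoint tensors directly,
# so the parameter dtype below comes from the checkpoint, not the module.
model = nn.Linear(4, 4)
ckpt = {"weight": torch.zeros(4, 4, dtype=torch.float16),
        "bias": torch.zeros(4, dtype=torch.float16)}

model.load_state_dict(ckpt, assign=True)
print(model.weight.dtype)   # torch.float16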
scripts/wan22/run_wan22_ti2v_i2v_audio.sh  (new file, mode 100755)

#!/bin/bash

# set path and first
lightx2v_path=
model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
    --model_cls wan2.2_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/audio_driven/wan22_ti2v_i2v_audio.json \
    --prompt "The video features a old lady is saying something and knitting a sweater." \
    --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4