xuwx1 / LightX2V · Commit 5ec2b691

Merge pull request #135 from ModelTC/audio_r2v

Audio r2v v2

Authored by sandy on Jul 17, 2025; committed by GitHub on Jul 17, 2025.
Parents: 6d07a72e, e08c4f90
Showing 7 changed files with 402 additions and 21 deletions (+402, -21).
configs/audio_driven/wan_i2v_audio.json                          (+4, -3)
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py  (+5, -2)
lightx2v/models/runners/wan/wan_audio_runner.py                  (+88, -11)
lightx2v/models/schedulers/wan/audio/audio.py                    (+145, -0, new file)
lightx2v/models/schedulers/wan/audio/scheduler.py                (+150, -0, new file)
scripts/wan/run_wan_i2v_audio.sh                                 (+3, -2)
tools/extract/convert_vigen_to_x2v_lora.py                       (+7, -3)
configs/audio_driven/wan_i2v_audio.json

 {
-    "infer_steps": 5,
+    "infer_steps": 4,
     "target_fps": 16,
-    "video_duration": 12,
+    "video_duration": 16,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 480,
@@ -13,5 +13,6 @@
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false
+    "cpu_offload": false,
+    "use_tiling_vae": true
 }
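For orientation: infer_steps drops from 5 to 4 and video_duration grows from 12 s to 16 s, while each generated segment stays target_video_length = 81 frames at target_fps = 16, about 5 s, so longer clips are stitched from overlapping segments by the runner below. A back-of-envelope segment count, where prev_frame_length is an assumed illustration value (the real overlap is decided in wan_audio_runner.py):

# Back-of-envelope segment count for the new config; prev_frame_length is a
# hypothetical overlap value, used here only for illustration.
import math

target_fps = 16
video_duration = 16        # seconds (was 12)
max_num_frames = 81        # target_video_length
prev_frame_length = 5      # assumed overlap carried between segments

total_frames = video_duration * target_fps                       # 256
stride = max_num_frames - prev_frame_length                      # 76 fresh frames per later segment
num_segments = 1 + math.ceil((total_frames - max_num_frames) / stride)
print(total_frames, num_segments)                                # 256 4 (under these assumptions)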
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -24,13 +24,15 @@ class WanAudioPreInfer(WanPreInfer):
         self.text_len = config["text_len"]

     def infer(self, weights, inputs, positive):
+        ltnt_channel = self.scheduler.latents.size(0)
+        ltnt_frames = self.scheduler.latents.size(1)
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
         prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
         hidden_states = self.scheduler.latents.unsqueeze(0)
-        hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape},{prev_latents[:, :, :ltnt_frames].shape}")
+        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
         hidden_states = hidden_states.squeeze(0)
         x = [hidden_states]
@@ -44,6 +46,7 @@ class WanAudioPreInfer(WanPreInfer):
                 "timestep": t,
             }
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
+            ##audio_dit_blocks = None##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
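The reordered concatenation changes the conditioning layout along the channel dimension: instead of truncating to ltnt_channel and appending latents then mask, the new code keeps the scheduler latents whole and appends the mask followed by the time-truncated previous latents. A shape-only sketch with dummy tensors; the sizes (16-channel latents, 21 latent frames for an 81-frame clip) are illustrative:

# Shape-only sketch of the new conditioning concat; sizes are illustrative.
import torch

ltnt_channel, ltnt_frames, h, w = 16, 21, 60, 104
hidden_states = torch.randn(1, ltnt_channel, ltnt_frames, h, w)  # scheduler latents
prev_latents = torch.randn(1, ltnt_channel, ltnt_frames, h, w)   # VAE-encoded previous frames
prev_mask = torch.randn(1, 4, ltnt_frames, h, w)                 # output of wan_mask_rearrange, unsqueezed

out = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
print(out.shape)  # torch.Size([1, 36, 21, 60, 104]): 16 + 4 + 16 channels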
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -18,6 +18,9 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
+from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
 from loguru import logger
 import torch.distributed as dist
 from einops import rearrange
@@ -33,6 +36,45 @@ import warnings
 from typing import Optional, Tuple, Union


+def add_mask_to_frames(
+    frames: np.ndarray,
+    mask_rate: float = 0.1,
+    rnd_state: np.random.RandomState = None,
+) -> np.ndarray:
+    if mask_rate is None:
+        return frames
+    if rnd_state is None:
+        rnd_state = np.random.RandomState()
+    h, w = frames.shape[-2:]
+    mask = rnd_state.rand(h, w) > mask_rate
+    frames = frames * mask
+    return frames
+
+
+def add_noise_to_frames(
+    frames: np.ndarray,
+    noise_mean: float = -3.0,
+    noise_std: float = 0.5,
+    rnd_state: np.random.RandomState = None,
+) -> np.ndarray:
+    if noise_mean is None or noise_std is None:
+        return frames
+    if rnd_state is None:
+        rnd_state = np.random.RandomState()
+    shape = frames.shape
+    bs = 1 if len(shape) == 4 else shape[0]
+    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
+    sigma = np.exp(sigma)
+    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
+    noise = rnd_state.randn(*shape) * sigma
+    frames = frames + noise
+    return frames
+
+
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
     tgt_ar = tgt_h / tgt_w
     ori_ar = ori_h / ori_w
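These two helpers lightly corrupt the overlap frames that seed the next segment: add_noise_to_frames adds Gaussian noise whose per-batch scale is drawn log-normally around exp(-3) ≈ 0.05, and add_mask_to_frames zeroes a random ~10% of (H, W) pixel positions. Corrupting the conditioning this way hedges against the model drifting on its own slightly-off previous output. A quick demonstration on a dummy clip, assuming this commit's module is importable:

# Demonstration of the corruption helpers on a dummy (B, C, T, H, W) clip.
import numpy as np
from lightx2v.models.runners.wan.wan_audio_runner import add_mask_to_frames, add_noise_to_frames

rnd = np.random.RandomState(0)
frames = np.ones((1, 3, 5, 32, 32), dtype=np.float32)

noisy = add_noise_to_frames(frames, rnd_state=rnd)                # noise scale sigma = exp(N(-3.0, 0.5))
masked = add_mask_to_frames(noisy, mask_rate=0.1, rnd_state=rnd)  # zero ~10% of pixel positions

print(np.abs(noisy - frames).mean())   # small perturbation, on the order of a few percent
print((masked == 0).mean())            # roughly 0.1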
@@ -75,7 +117,12 @@ def adaptive_resize(img):
     aspect_ratios = np.array(np.array(list(bucket_config.keys())))
     closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
     closet_ratio = aspect_ratios[closet_aspect_idx]
     if ori_ratio < 1.0:
         target_h, target_w = 480, 832
+    elif ori_ratio == 1.0:
+        target_h, target_w = 480, 480
+    else:
+        target_h, target_w = 832, 480
     for resolution in bucket_config[closet_ratio][0]:
         if ori_height * ori_weight >= resolution[0] * resolution[1]:
             target_h, target_w = resolution
@@ -253,6 +300,10 @@ class WanAudioRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)

+    def init_scheduler(self):
+        scheduler = EulerSchedulerTimestepFix(self.config)
+        self.model.set_scheduler(scheduler)
+
     def load_audio_models(self):
         ## audio feature extractor
         self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
@@ -372,6 +423,18 @@ class WanAudioRunner(WanRunner):
     audio_frame_rate = audio_sr / fps
     return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)


+def wan_mask_rearrange(mask: torch.Tensor):
+    # mask: 1, T, H, W, where 1 means the input mask is one-channel
+    if mask.ndim == 3:
+        mask = mask[None]
+    assert mask.ndim == 4
+    _, t, h, w = mask.shape
+    assert t == ((t - 1) // 4 * 4 + 1)
+    mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
+    mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
+    mask = mask.view(mask.shape[1] // 4, 4, h, w)
+    return mask.transpose(0, 1)  # 4, T // 4, H, W
+
...
         self.inputs["audio_adapter_pipe"] = self.load_audio_models()
         # process audio
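wan_mask_rearrange packs a one-channel temporal mask into the VAE's 4x temporal-compression layout: the first frame is repeated four times, then consecutive groups of four frames become channels. A shape walk-through with sizes mirroring the 480x832, 81-frame case:

# Shape walk-through of wan_mask_rearrange; sizes are illustrative.
import torch

t, h, w = 81, 60, 104            # pixel-space frame count, latent-space height/width
mask = torch.zeros(1, t, h, w)
mask[:, :21] = 1.0               # mark the overlap frames carried from the previous segment

first = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)  # (1, 4, h, w)
packed = torch.concat([first, mask[:, 1:]], dim=1)               # (1, 84, h, w)
packed = packed.view(packed.shape[1] // 4, 4, h, w).transpose(0, 1)
print(packed.shape)              # torch.Size([4, 21, 60, 104]): 4 channels, (t - 1) // 4 + 1 frames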
@@ -427,7 +490,14 @@ class WanAudioRunner(WanRunner):
             elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
@@ -438,7 +508,14 @@ class WanAudioRunner(WanRunner):
             else:  # middle segments are a full 81 frames, with prev_latents
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
@@ -452,11 +529,11 @@ class WanAudioRunner(WanRunner):
         if prev_latents is not None:
             ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
-            bs = 1
-            prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
+            # bs = 1
+            frames_n = (nframe - 1) * 4 + 1
+            if prev_len > 0:
+                prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype)
+                prev_mask[:, :, :prev_len] = 1.0
+                prev_mask[:, prev_len:] = 0
+                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
             previmg_encoder_output = {
                 "prev_latents": prev_latents,
                 "prev_mask": prev_mask,
@@ -483,13 +560,13 @@ class WanAudioRunner(WanRunner):
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
             if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
+                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             else:
-                gen_video_list.append(gen_video[:, :, start_frame:])
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:])
             gen_lvideo = torch.cat(gen_video_list, dim=2).float()
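The .cpu() calls added to the three append branches move each finished segment to host memory as soon as it is cut, so GPU memory stays bounded by the segment in flight rather than the whole video; torch.cat then assembles the final tensor on the CPU. The pattern in miniature (tensor sizes downscaled so the sketch runs cheaply):

# Miniature of the segment-offload pattern introduced here.
import torch

gen_video_list = []
for _ in range(4):                              # four segments, as in the 16 s example above
    seg = torch.randn(1, 3, 81, 48, 64)         # stand-in for a GPU-resident decoded segment
    gen_video_list.append(seg.cpu())            # offload immediately; the real code also trims the overlap
gen_lvideo = torch.cat(gen_video_list, dim=2).float()   # final concat happens on the CPU
print(gen_lvideo.shape)                         # torch.Size([1, 3, 324, 48, 64])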
lightx2v/models/schedulers/wan/audio/audio.py (new file, 0 → 100644)

import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def get_timesteps(num_steps, max_steps: int = 1000):
    return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)


def timestep_shift(timesteps, shift: float = 1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)


class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_noise_sigma = 1.0

    def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
        dtype = x0.dtype
        device = x0.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        sigma = unsqueeze_to_ndim(sigma, x0.ndim)
        xt = x0.float() * (1 - sigma) + noise.float() * sigma
        return xt.to(dtype)

    def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
        return noise - x0

    def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
        device = v_loss.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        return v_loss.float() * (sigma**2)


class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
    def __init__(self, config):
        self.config = config
        self.step_index = 0
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.noise_pred = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_shift(self, shift: float = 1.0):
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = timestep_shift(self.sigmas, shift=shift)
        self.timesteps = self.sigmas * self.num_train_timesteps

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
        self.timesteps_ori = self.timesteps.clone()
        self.set_shift(self.sample_shift)
        self._step_index = None
        self._begin_index = None

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if os.path.isfile(self.config.image_path):
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        else:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)
        if self.step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)  # pyright: ignore
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self._step_index += 1
        return x_t_next

    def reset(self):
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
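timestep_shift implements the flow-matching warp sigma' = shift * sigma / (1 + (shift - 1) * sigma); with sample_shift = 5 it pushes the few sampling steps toward high-noise sigmas, which matters when infer_steps is as small as 4. A quick numeric check of the schedule this file would produce:

# Numeric check of the shift warp applied by set_shift (printed values rounded).
import numpy as np

def timestep_shift(timesteps, shift=1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)

sigmas = np.linspace(1000, 0, 4 + 1, dtype=np.float32) / 1000  # get_timesteps(4) scaled to [0, 1]
print(sigmas)                        # [1.   0.75 0.5  0.25 0.  ]
print(timestep_shift(sigmas, 5.0))   # [1.     0.9375 0.8333 0.625  0.    ]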
lightx2v/models/schedulers/wan/audio/scheduler.py (new file, 0 → 100755)

import os
import gc
import math
import warnings  # needed by unsqueeze_to_ndim below
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor


class EulerSchedulerTimestepFix(BaseScheduler):
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.solver_order = 2
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.shift = 1
        self.num_train_timesteps = 1000
        self.step_index = None
        self.noise_pred = None
        self._step_index = None
        self._begin_index = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        sigma_last = 0
        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        assert len(self.timesteps) == self.infer_steps
        self.model_outputs = [None] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sample = sample.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next

    def reset(self):
        self.model_outputs = [None]
        self.timestep_list = [None]
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
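A minimal usage sketch of the new scheduler, assuming a CUDA device (the class pins itself to cuda) and a config object exposing the attributes the constructor and prepare() read; every value below is invented for illustration:

# Hypothetical driver for EulerSchedulerTimestepFix; config values are invented.
from types import SimpleNamespace
import torch
from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix

config = SimpleNamespace(
    infer_steps=4,
    target_video_length=81,
    sample_shift=5,
    seed=42,
    task="i2v",
    target_shape=(16, 21, 60, 104),
    vae_stride=(4, 8, 8),
    lat_h=60,
    lat_w=104,
    patch_size=(1, 2, 2),
)

sched = EulerSchedulerTimestepFix(config)
sched.prepare()                     # seeds the generator, draws initial latents, builds shifted sigmas
for i in range(sched.infer_steps):
    sched.step_pre(i)               # may cast latents to bf16 depending on GET_DTYPE()
    sched.noise_pred = torch.zeros_like(sched.latents)  # stand-in for the DiT velocity prediction
    sched.step_post()               # Euler update: x += (sigma_next - sigma) * v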
scripts/wan/run_wan_i2v_audio.sh

 #!/bin/bash

 # set path and first
 lightx2v_path=
 model_path=
 lora_path=

 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
@@ -42,5 +44,4 @@ python -m lightx2v.infer \
     --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
     --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
     --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
-    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
-    --lora_path ${lora_path}
+    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
tools/extract/convert_vigen_to_x2v_lora.py

@@ -5,12 +5,13 @@
 ### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
 ###
 import torch
-from safetensors.torch import save_file
 import sys
 import os
+from safetensors.torch import save_file
+from safetensors.torch import load_file

 if len(sys.argv) != 3:
-    print("Usage: python convert_lora.py <input_file.pt> <output_file.safetensors>")
+    print("Usage: python convert_lora.py <input_file> <output_file.safetensors>")
     sys.exit(1)

 ckpt_path = sys.argv[1]
@@ -20,7 +21,10 @@ if not os.path.exists(ckpt_path):
     print(f"❌ Input file not found: {ckpt_path}")
     sys.exit(1)

-state_dict = torch.load(ckpt_path, map_location="cpu")
+if ckpt_path.endswith(".safetensors"):
+    state_dict = load_file(ckpt_path)
+else:
+    state_dict = torch.load(ckpt_path, map_location="cpu")

 if "state_dict" in state_dict:
     state_dict = state_dict["state_dict"]

(The usage and error strings above were originally in Chinese; they are shown translated here.)
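With the new extension check, the converter accepts either a legacy PyTorch checkpoint or a .safetensors file as its first argument, e.g. `python tools/extract/convert_vigen_to_x2v_lora.py vigen_lora.safetensors x2v_lora.safetensors` (file names here are illustrative).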