xuwx1 / LightX2V

Commit 5ec2b691, authored Jul 17, 2025 by sandy, committed by GitHub on Jul 17, 2025

Merge pull request #135 from ModelTC/audio_r2v

Audio r2v v2

Parents: 6d07a72e, e08c4f90

Showing 7 changed files with 402 additions and 21 deletions.
configs/audio_driven/wan_i2v_audio.json                              +4   -3
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py      +5   -2
lightx2v/models/runners/wan/wan_audio_runner.py                      +88  -11
lightx2v/models/schedulers/wan/audio/audio.py                        +145 -0
lightx2v/models/schedulers/wan/audio/scheduler.py                    +150 -0
scripts/wan/run_wan_i2v_audio.sh                                     +3   -2
tools/extract/convert_vigen_to_x2v_lora.py                           +7   -3
configs/audio_driven/wan_i2v_audio.json

{
-    "infer_steps": 5,
+    "infer_steps": 4,
    "target_fps": 16,
-    "video_duration": 12,
+    "video_duration": 16,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 480,
...
@@ -13,5 +13,6 @@
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
-    "cpu_offload": false
+    "cpu_offload": false,
+    "use_tiling_vae": true
}
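For orientation, a minimal sketch (not part of the commit) of how the fields above relate, assuming the file is read as plain JSON: video_duration and target_fps set the overall frame budget, while target_video_length is the per-window length consumed by the runner changes further down.

import json

with open("configs/audio_driven/wan_i2v_audio.json") as f:
    cfg = json.load(f)

total_frames = cfg["video_duration"] * cfg["target_fps"]   # 16 s * 16 fps = 256 frames overall
per_window = cfg["target_video_length"]                     # 81 frames generated per window
print(total_frames, per_window, cfg["infer_steps"])         # 256 81 4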
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

...
@@ -24,13 +24,15 @@ class WanAudioPreInfer(WanPreInfer):
        self.text_len = config["text_len"]

    def infer(self, weights, inputs, positive):
        ltnt_channel = self.scheduler.latents.size(0)
+        ltnt_frames = self.scheduler.latents.size(1)
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
        hidden_states = self.scheduler.latents.unsqueeze(0)
-        hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
+        # print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape},{prev_latents[:, :, :ltnt_frames].shape}")
+        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
        hidden_states = hidden_states.squeeze(0)
        x = [hidden_states]
...
@@ -44,6 +46,7 @@ class WanAudioPreInfer(WanPreInfer):
                "timestep": t,
            }
            audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
+            ##audio_dit_blocks = None  ##Debug Drop Audio
        if positive:
            context = inputs["text_encoder_output"]["context"]
...
lightx2v/models/runners/wan/wan_audio_runner.py

...
@@ -18,6 +18,9 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+ from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
from loguru import logger
import torch.distributed as dist
+ from einops import rearrange
...
@@ -33,6 +36,45 @@ import warnings
from typing import Optional, Tuple, Union


def add_mask_to_frames(
    frames: np.ndarray,
    mask_rate: float = 0.1,
    rnd_state: np.random.RandomState = None,
) -> np.ndarray:
    if mask_rate is None:
        return frames
    if rnd_state is None:
        rnd_state = np.random.RandomState()
    h, w = frames.shape[-2:]
    mask = rnd_state.rand(h, w) > mask_rate
    frames = frames * mask
    return frames


def add_noise_to_frames(
    frames: np.ndarray,
    noise_mean: float = -3.0,
    noise_std: float = 0.5,
    rnd_state: np.random.RandomState = None,
) -> np.ndarray:
    if noise_mean is None or noise_std is None:
        return frames
    if rnd_state is None:
        rnd_state = np.random.RandomState()
    shape = frames.shape
    bs = 1 if len(shape) == 4 else shape[0]
    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
    sigma = np.exp(sigma)
    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
    noise = rnd_state.randn(*shape) * sigma
    frames = frames + noise
    return frames


def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
    tgt_ar = tgt_h / tgt_w
    ori_ar = ori_h / ori_w
...
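The two new helpers above corrupt the frames carried over from the previous clip before they are re-encoded as conditioning. A small usage sketch (illustrative shapes, assuming the two functions as defined above):

import numpy as np

rnd = np.random.RandomState(0)
prev = np.ones((1, 3, 5, 480, 832), dtype=np.float32)              # (B, C, T, H, W) tail of the previous clip

noisy = add_noise_to_frames(prev, noise_mean=-3.0, noise_std=0.5, rnd_state=rnd)
degraded = add_mask_to_frames(noisy, mask_rate=0.1, rnd_state=rnd)
# Gaussian noise is scaled per sample by exp(N(-3.0, 0.5)) ~ 0.05, and roughly 10%
# of spatial positions are zeroed across all frames and channels.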
@@ -75,7 +117,12 @@ def adaptive_resize(img):
    aspect_ratios = np.array(np.array(list(bucket_config.keys())))
    closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
    closet_ratio = aspect_ratios[closet_aspect_idx]
-    target_h, target_w = 480, 832
+    if ori_ratio < 1.0:
+        target_h, target_w = 480, 832
+    elif ori_ratio == 1.0:
+        target_h, target_w = 480, 480
+    else:
+        target_h, target_w = 832, 480
    for resolution in bucket_config[closet_ratio][0]:
        if ori_height * ori_weight >= resolution[0] * resolution[1]:
            target_h, target_w = resolution
...
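The change above seeds target_h/target_w from the orientation of the input before the bucket scan instead of always starting from 480x832. The branch in isolation (bucket_config and the final scan live in the surrounding file and are omitted here):

def default_resolution(ori_h: int, ori_w: int):
    ori_ratio = ori_h / ori_w
    if ori_ratio < 1.0:        # wider than tall -> landscape target
        return 480, 832
    elif ori_ratio == 1.0:     # square
        return 480, 480
    else:                      # taller than wide -> portrait target
        return 832, 480

print(default_resolution(720, 1280))   # (480, 832)
print(default_resolution(1280, 720))   # (832, 480)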
@@ -253,6 +300,10 @@ class WanAudioRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

+    def init_scheduler(self):
+        scheduler = EulerSchedulerTimestepFix(self.config)
+        self.model.set_scheduler(scheduler)
+
    def load_audio_models(self):
        ## Audio feature extractor
        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
...
@@ -372,6 +423,18 @@ class WanAudioRunner(WanRunner):
            audio_frame_rate = audio_sr / fps
            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
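Worked example of the audio-window arithmetic above, restated with a signature matching the calls later in this diff (fps and audio_sr default to the config values, 16 and 16000):

def get_audio_range(start_frame, end_frame, fps=16, audio_sr=16000):
    audio_frame_rate = audio_sr / fps                  # 1000 audio samples per video frame
    return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

print(get_audio_range(0, 80))     # (0, 81000): samples covering video frames 0..80
print(get_audio_range(81, 161))   # (81000, 162000)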
        def wan_mask_rearrange(mask: torch.Tensor):
            # mask: 1, T, H, W, where 1 means the input mask is one-channel
            if mask.ndim == 3:
                mask = mask[None]
            assert mask.ndim == 4
            _, t, h, w = mask.shape
            assert t == ((t - 1) // 4 * 4 + 1)
            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
            mask = mask.view(mask.shape[1] // 4, 4, h, w)
            return mask.transpose(0, 1)  # 4, T // 4, H, W

        self.inputs["audio_adapter_pipe"] = self.load_audio_models()

        # process audio
...
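A quick shape check for wan_mask_rearrange, treating it as a standalone function (sizes are illustrative): a single-channel mask over T = 4k + 1 video frames becomes a 4-channel mask over the corresponding latent frames.

import torch

T, H, W = 81, 60, 104                   # 81 == (81 - 1) // 4 * 4 + 1, so the assert passes
mask = torch.zeros(1, T, H, W)
out = wan_mask_rearrange(mask)          # first frame repeated 4x, then grouped in blocks of 4
print(out.shape)                        # torch.Size([4, 21, 60, 104])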
@@ -427,7 +490,14 @@ class WanAudioRunner(WanRunner):
                elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
                    prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                    prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                    last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                    last_frames = last_frames.cpu().detach().numpy()
+                    last_frames = add_noise_to_frames(last_frames)
+                    last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                    last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                    prev_frames[:, :, :prev_frame_length] = last_frames
                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                    prev_len = prev_token_length
                    audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
...
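The hunk above and the nearly identical one below sit inside a sliding-window loop over the full clip. A worked example of the index bookkeeping they use, assuming max_num_frames equals the configured target_video_length (81); prev_frame_length is a hypothetical value here:

max_num_frames = 81
prev_frame_length = 5
for idx in range(3):
    start = idx * max_num_frames - idx * prev_frame_length        # = idx * (81 - 5)
    end = (idx + 1) * max_num_frames - idx * prev_frame_length
    print(idx, start, end)   # 0 0 81 / 1 76 157 / 2 152 233: consecutive windows overlap by 5 frames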
@@ -438,7 +508,14 @@ class WanAudioRunner(WanRunner):
                else:  # intermediate segments are a full 81 frames and carry prev_latents
                    prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                    prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                    last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                    last_frames = last_frames.cpu().detach().numpy()
+                    last_frames = add_noise_to_frames(last_frames)
+                    last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                    last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                    prev_frames[:, :, :prev_frame_length] = last_frames
                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                    prev_len = prev_token_length
                    audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
...
@@ -452,11 +529,11 @@ class WanAudioRunner(WanRunner):
            if prev_latents is not None:
                ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
-                bs = 1
-                prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
-                if prev_len > 0:
-                    prev_mask[:, :, :prev_len] = 1.0
+                # bs = 1
+                frames_n = (nframe - 1) * 4 + 1
+                prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype)
+                prev_mask[:, prev_len:] = 0
+                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
                previmg_encoder_output = {
                    "prev_latents": prev_latents,
                    "prev_mask": prev_mask,
...
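Illustrative shapes for the conditioning-mask path above: the pixel-space mask spans (nframe - 1) * 4 + 1 frames, and wan_mask_rearrange regroups it into the latent frame layout before it is handed over together with prev_latents (sizes are assumptions; the helper is the one from the hunk earlier in this file):

import torch

nframe, height, width = 21, 60, 104
frames_n = (nframe - 1) * 4 + 1                      # 81 pixel-space frames for 21 latent frames
prev_mask = torch.zeros((1, frames_n, height, width))
prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
print(prev_mask.shape)                               # torch.Size([1, 4, 21, 60, 104])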
@@ -483,13 +560,13 @@ class WanAudioRunner(WanRunner):
            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
            if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
+                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            else:
-                gen_video_list.append(gen_video[:, :, start_frame:])
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:])
        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
...
lightx2v/models/schedulers/wan/audio/audio.py  (new file, mode 100644)

import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def get_timesteps(num_steps, max_steps: int = 1000):
    return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)


def timestep_shift(timesteps, shift: float = 1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)
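With the values used by this commit's config (infer_steps=4, sample_shift=5), the two helpers above produce the following schedule; the shift pushes the intermediate steps toward the high-noise end:

import numpy as np

t = np.linspace(1000, 0, 4 + 1, dtype=np.float32)    # [1000., 750., 500., 250., 0.]
sigmas = t / 1000.0                                   # [1.0, 0.75, 0.5, 0.25, 0.0]
shifted = 5 * sigmas / (1 + (5 - 1) * sigmas)         # [1.0, 0.9375, 0.8333, 0.625, 0.0]
print(shifted * 1000)                                 # shifted timesteps handed to the model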
class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_noise_sigma = 1.0

    def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
        dtype = x0.dtype
        device = x0.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        sigma = unsqueeze_to_ndim(sigma, x0.ndim)
        xt = x0.float() * (1 - sigma) + noise.float() * sigma
        return xt.to(dtype)

    def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
        return noise - x0

    def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
        device = v_loss.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        return v_loss.float() * (sigma**2)


class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
    def __init__(self, config):
        self.config = config
        self.step_index = 0
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.noise_pred = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_shift(self, shift: float = 1.0):
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = timestep_shift(self.sigmas, shift=shift)
        self.timesteps = self.sigmas * self.num_train_timesteps

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
        self.timesteps_ori = self.timesteps.clone()
        self.set_shift(self.sample_shift)
        self._step_index = None
        self._begin_index = None

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if os.path.isfile(self.config.image_path):
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        else:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)
        if self.step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)  # pyright: ignore
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self._step_index += 1
        return x_t_next
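A numeric sanity check of the Euler update in step_post: with the velocity convention v = noise - x0 (see get_velocity above), stepping from sigma to sigma_next moves the sample exactly along the straight line between x0 and pure noise:

import torch

x0, noise = torch.tensor(2.0), torch.tensor(-1.0)
sigma, sigma_next = 0.9375, 0.8333
xt = x0 * (1 - sigma) + noise * sigma                 # add_noise() at the current sigma
v = noise - x0                                        # the ideal velocity prediction
x_next = xt + (sigma_next - sigma) * v                # the update performed by step_post
print(torch.allclose(x_next, x0 * (1 - sigma_next) + noise * sigma_next))   # True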
    def reset(self):
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
lightx2v/models/schedulers/wan/audio/scheduler.py  (new file, mode 100755)

import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor
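Quick check of unsqueeze_to_ndim: trailing singleton dimensions are appended so a per-sample sigma can broadcast against a (B, C, T, H, W) latent batch, and tensors that already have enough dimensions pass through unchanged:

import torch

sigma = torch.tensor([0.5, 0.25])             # shape (2,)
print(unsqueeze_to_ndim(sigma, 5).shape)      # torch.Size([2, 1, 1, 1, 1])
print(unsqueeze_to_ndim(sigma, 1).shape)      # unchanged: torch.Size([2])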
class EulerSchedulerTimestepFix(BaseScheduler):
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.solver_order = 2
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.shift = 1
        self.num_train_timesteps = 1000
        self.step_index = None
        self.noise_pred = None
        self._step_index = None
        self._begin_index = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        sigma_last = 0
        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        assert len(self.timesteps) == self.infer_steps
        self.model_outputs = [None,] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sample = sample.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next

    def reset(self):
        self.model_outputs = [None]
        self.timestep_list = [None]
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
scripts/wan/run_wan_i2v_audio.sh

#!/bin/bash

# set path and first
lightx2v_path=
model_path=
lora_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
...
@@ -42,5 +44,4 @@ python -m lightx2v.infer \
    --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
-    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
-    --lora_path ${lora_path}
+    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
tools/extract/convert_vigen_to_x2v_lora.py

...
@@ -5,12 +5,13 @@
### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
###

import torch
- from safetensors.torch import save_file
import sys
import os
+ from safetensors.torch import save_file
+ from safetensors.torch import load_file

if len(sys.argv) != 3:
-    print("Usage: python convert_lora.py <input_file.pt> <output_file.safetensors>")
+    print("Usage: python convert_lora.py <input_file> <output_file.safetensors>")
    sys.exit(1)

ckpt_path = sys.argv[1]
...
@@ -20,7 +21,10 @@ if not os.path.exists(ckpt_path):
    print(f"❌ Input file does not exist: {ckpt_path}")
    sys.exit(1)

- state_dict = torch.load(ckpt_path, map_location="cpu")
+ if ckpt_path.endswith(".safetensors"):
+     state_dict = load_file(ckpt_path)
+ else:
+     state_dict = torch.load(ckpt_path, map_location="cpu")

if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
...