xuwx1 / LightX2V · Commits · 12bfd120

Commit 12bfd120, authored Jul 17, 2025 by wangshankun
Parent: 740d8d8f

    Audio: Update to R2V V2

Showing 7 changed files with 374 additions and 10 deletions (+374, −10)
Files changed:
  configs/audio_driven/wan_i2v_audio.json                           +1   −1
  lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py   +1   −0
  lightx2v/models/runners/wan/wan_audio_runner.py                   +69  −6
  lightx2v/models/schedulers/wan/audio/audio.py                     +145 −0
  lightx2v/models/schedulers/wan/audio/scheduler.py                 +150 −0
  scripts/wan/run_wan_i2v_audio.sh                                  +1   −0
  tools/extract/convert_vigen_to_x2v_lora.py                        +7   −3
configs/audio_driven/wan_i2v_audio.json

@@ -11,7 +11,7 @@
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
-    "sample_shift": 6,
+    "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "use_tiling_vae": true
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -46,6 +46,7 @@ class WanAudioPreInfer(WanPreInfer):
             "timestep": t,
         }
         audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
+        ##audio_dit_blocks = None##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -19,6 +19,7 @@
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
 from loguru import logger
 import torch.distributed as dist
@@ -35,6 +36,45 @@ import warnings
 from typing import Optional, Tuple, Union
 
 
+def add_mask_to_frames(
+    frames: np.ndarray,
+    mask_rate: float = 0.1,
+    rnd_state: np.random.RandomState = None,
+) -> np.ndarray:
+    if mask_rate is None:
+        return frames
+    if rnd_state is None:
+        rnd_state = np.random.RandomState()
+    h, w = frames.shape[-2:]
+    mask = rnd_state.rand(h, w) > mask_rate
+    frames = frames * mask
+    return frames
+
+
+def add_noise_to_frames(
+    frames: np.ndarray,
+    noise_mean: float = -3.0,
+    noise_std: float = 0.5,
+    rnd_state: np.random.RandomState = None,
+) -> np.ndarray:
+    if noise_mean is None or noise_std is None:
+        return frames
+    if rnd_state is None:
+        rnd_state = np.random.RandomState()
+    shape = frames.shape
+    bs = 1 if len(shape) == 4 else shape[0]
+    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
+    sigma = np.exp(sigma)
+    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
+    noise = rnd_state.randn(*shape) * sigma
+    frames = frames + noise
+    return frames
+
+
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
     tgt_ar = tgt_h / tgt_w
     ori_ar = ori_h / ori_w
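These two helpers degrade the tail frames of the previous clip before it is re-encoded as conditioning for the next clip (see the runner hunks below): Gaussian noise with a log-normally sampled scale, then random pixel dropout. A minimal sketch of how they compose, assuming the two functions above are in scope and using an illustrative clip shape:

    import numpy as np

    # Illustrative conditioning clip: (batch, channels, frames, height, width)
    frames = np.random.rand(1, 3, 5, 480, 832).astype(np.float32)

    rnd = np.random.RandomState(42)  # fixed seed so the degradation is reproducible
    out = add_noise_to_frames(frames, noise_mean=-3.0, noise_std=0.5, rnd_state=rnd)
    out = add_mask_to_frames(out, mask_rate=0.1, rnd_state=rnd)

    # exp(N(-3.0, 0.5)) is ~0.05 on average, so the added noise is mild; the
    # (h, w) mask is shared across frames and zeroes roughly 10% of positions.
    assert out.shape == frames.shape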
@@ -77,7 +117,12 @@ def adaptive_resize(img):
     aspect_ratios = np.array(np.array(list(bucket_config.keys())))
     closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
     closet_ratio = aspect_ratios[closet_aspect_idx]
-    target_h, target_w = 480, 832
+    if ori_ratio < 1.0:
+        target_h, target_w = 480, 832
+    elif ori_ratio == 1.0:
+        target_h, target_w = 480, 480
+    else:
+        target_h, target_w = 832, 480
     for resolution in bucket_config[closet_ratio][0]:
         if ori_height * ori_weight >= resolution[0] * resolution[1]:
             target_h, target_w = resolution
@@ -255,6 +300,10 @@ class WanAudioRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
 
+    def init_scheduler(self):
+        scheduler = EulerSchedulerTimestepFix(self.config)
+        self.model.set_scheduler(scheduler)
+
     def load_audio_models(self):
         ## Audio feature extractor
         self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
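The new `init_scheduler` override installs `EulerSchedulerTimestepFix` in place of the step-distill scheduler; note the import above resolves to `schedulers/wan/audio/scheduler.py`, not the sibling `audio.py`, which defines a same-named class. Presumably the runner's setup path reaches it along these lines (a sketch; the call site is not part of this diff):

    runner = WanAudioRunner(config)
    runner.load_audio_models()   # feature extractor + audio adapter
    runner.init_scheduler()      # installs EulerSchedulerTimestepFix on the model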
@@ -438,7 +487,14 @@ class WanAudioRunner(WanRunner):
             elif res_frame_num > 5 and idx == interval_num - 1:
                 # The last segment may have fewer than 81 frames
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
@@ -449,7 +505,14 @@ class WanAudioRunner(WanRunner):
             else:
                 # Middle segments are a full 81 frames, with prev_latents
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
+                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
+                last_frames = last_frames.cpu().detach().numpy()
+                last_frames = add_noise_to_frames(last_frames)
+                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
+                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
+                prev_frames[:, :, :prev_frame_length] = last_frames
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                 prev_len = prev_token_length
                 audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
@@ -494,13 +557,13 @@ class WanAudioRunner(WanRunner):
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
             if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
+                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
             else:
-                gen_video_list.append(gen_video[:, :, start_frame:])
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                 cut_audio_list.append(audio_array[start_audio_frame:])
 
         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
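Appending each finished segment with `.cpu()` moves it to host memory right away, so the GPU holds only the segment currently being generated; the final `torch.cat(..., dim=2)` then runs on CPU. The pattern in isolation (shapes illustrative):

    import torch

    gen_video_list = []
    for _ in range(3):
        seg = torch.randn(1, 3, 81, 480, 832)  # lives on CUDA in the real runner
        gen_video_list.append(seg.cpu())       # offload immediately, freeing VRAM
    gen_lvideo = torch.cat(gen_video_list, dim=2).float()  # concat along frame dim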
lightx2v/models/schedulers/wan/audio/audio.py (new file, mode 100644)

import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def get_timesteps(num_steps, max_steps: int = 1000):
    return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)


def timestep_shift(timesteps, shift: float = 1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)
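`timestep_shift` is the standard flow-matching shift σ′ = s·σ / (1 + (s − 1)·σ): for s > 1 it biases the schedule toward high noise, which is what the config change from `"sample_shift": 6` to `5` tunes. A quick check, assuming normalized inputs in [0, 1]:

    import numpy as np

    sigmas = np.linspace(1.0, 0.0, 5)              # normalized timesteps
    shifted = 5 * sigmas / (1 + (5 - 1) * sigmas)  # timestep_shift(..., shift=5)
    print(shifted)  # [1.0, 0.9375, 0.8333, 0.625, 0.0] — steps cluster near sigma=1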
class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_noise_sigma = 1.0

    def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
        dtype = x0.dtype
        device = x0.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        sigma = unsqueeze_to_ndim(sigma, x0.ndim)
        xt = x0.float() * (1 - sigma) + noise.float() * sigma
        return xt.to(dtype)

    def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
        return noise - x0

    def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
        device = v_loss.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        return v_loss.float() * (sigma**2)
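`add_noise` and `get_velocity` encode the linear flow-matching path x_σ = (1 − σ)·x₀ + σ·ε with velocity target v = ε − x₀, so moving along v from any σ back to 0 recovers x₀ exactly. A small sanity check of that identity:

    import torch

    x0, eps = torch.randn(4), torch.randn(4)
    sigma = 0.3
    xt = (1 - sigma) * x0 + sigma * eps  # add_noise at sigma=0.3
    v = eps - x0                         # get_velocity
    # One Euler step from sigma=0.3 all the way to sigma=0 is exact here,
    # because the path is linear in sigma:
    assert torch.allclose(xt + (0.0 - sigma) * v, x0, atol=1e-6)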
class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
    def __init__(self, config):
        self.config = config
        self.step_index = 0
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.noise_pred = None
    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_shift(self, shift: float = 1.0):
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = timestep_shift(self.sigmas, shift=shift)
        self.timesteps = self.sigmas * self.num_train_timesteps
    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
        self.timesteps_ori = self.timesteps.clone()
        self.set_shift(self.sample_shift)
        self._step_index = None
        self._begin_index = None
    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if os.path.isfile(self.config.image_path):
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        else:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)
    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)
        if self.step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)  # pyright: ignore
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self._step_index += 1
        return x_t_next
    def reset(self):
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
lightx2v/models/schedulers/wan/audio/scheduler.py (new file, mode 100755)

import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor
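`unsqueeze_to_ndim` right-pads trailing singleton dimensions so the scalar per-step sigma broadcasts against the latent video tensor in `step_post`. For example (latent shape illustrative):

    import torch

    sigma = torch.tensor(0.5)                 # 0-dim sigma for the current step
    sample = torch.randn(1, 16, 21, 60, 104)  # latent video, ndim=5
    sigma5 = unsqueeze_to_ndim(sigma, sample.ndim)
    print(sigma5.shape)  # torch.Size([1, 1, 1, 1, 1]) — broadcasts cleanly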
class EulerSchedulerTimestepFix(BaseScheduler):
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.solver_order = 2
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.shift = 1
        self.num_train_timesteps = 1000
        self.step_index = None
        self.noise_pred = None
        self._step_index = None
        self._begin_index = None
    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)
    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        sigma_last = 0
        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        assert len(self.timesteps) == self.infer_steps
        self.model_outputs = [
            None,
        ] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")
    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sample = sample.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next
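`step_post` performs one first-order Euler update of the flow ODE, x ← x + (σ_next − σ)·v, writing the result back into `self.latents`; after the final step σ_next = 0 and the latents are the clean sample. A hedged sketch of the loop a runner would drive around `step_pre`/`step_post` (the `model` call is a stand-in, not LightX2V's API):

    scheduler = EulerSchedulerTimestepFix(config)
    scheduler.prepare()                    # seeds the RNG, builds sigmas/timesteps
    for i in range(scheduler.infer_steps):
        scheduler.step_pre(i)              # set step index, cast to BF16 if configured
        scheduler.noise_pred = model(scheduler.latents, scheduler.timesteps[i])  # stand-in
        scheduler.step_post()              # one Euler step toward sigma=0
    clean_latents = scheduler.latents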
    def reset(self):
        self.model_outputs = [None]
        self.timestep_list = [None]
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
scripts/wan/run_wan_i2v_audio.sh

@@ -4,6 +4,7 @@
 lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
 model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-R2V-Audio-14B-720P/"
 #lora_path="/mnt/Text2Video/wuzhuguanyu/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors"
+#lora_path="/mnt/aigc/qiuzesong/Distill/DMD2/0716lightx2v/LightX2V/tools/extract/wan_r2v_V2_14B_lora_ran32.safetensors"
 
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
tools/extract/convert_vigen_to_x2v_lora.py

@@ -5,12 +5,13 @@
 ### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
 ###
 import torch
-from safetensors.torch import save_file
 import sys
 import os
+from safetensors.torch import save_file
+from safetensors.torch import load_file
 
 if len(sys.argv) != 3:
-    print("Usage: python convert_lora.py <input .pt file> <output .safetensors file>")
+    print("Usage: python convert_lora.py <input file> <output .safetensors file>")
     sys.exit(1)
 
 ckpt_path = sys.argv[1]

@@ -20,7 +21,10 @@ if not os.path.exists(ckpt_path):
     print(f"❌ Input file does not exist: {ckpt_path}")
     sys.exit(1)
 
-state_dict = torch.load(ckpt_path, map_location="cpu")
+if ckpt_path.endswith(".safetensors"):
+    state_dict = load_file(ckpt_path)
+else:
+    state_dict = torch.load(ckpt_path, map_location="cpu")
 
 if "state_dict" in state_dict:
     state_dict = state_dict["state_dict"]
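With the new `load_file` branch, the converter accepts either a `.pt` checkpoint or a `.safetensors` file as input, matching the relaxed usage string. For example (paths illustrative):

    python tools/extract/convert_vigen_to_x2v_lora.py vigen_lora.pt wan_lora.safetensors
    python tools/extract/convert_vigen_to_x2v_lora.py vigen_lora.safetensors wan_lora.safetensors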