xuwx1 / LightX2V · Commits

Commit 58c321cc
authored Jul 22, 2025 by wangshankun
fix bug: r2v prev_frame mask
parent 9b2c5b6b
Showing 4 changed files with 14 additions and 158 deletions (+14 -158)
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py   +1 -3
lightx2v/models/runners/wan/wan_audio_runner.py                   +12 -7
lightx2v/models/schedulers/wan/audio/audio.py                     +0 -145
lightx2v/models/schedulers/wan/audio/scheduler.py                 +1 -3
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -30,9 +30,7 @@ class WanAudioPreInfer(WanPreInfer):
         prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
         hidden_states = self.scheduler.latents.unsqueeze(0)
-        # hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
-        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
-        # print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape},{prev_latents[:, :, :ltnt_frames].shape}")
+        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
         hidden_states = hidden_states.squeeze(0)
         x = [hidden_states]
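Note: torch.cat along dim=1 (channels) requires every other dimension to agree, so prev_latents must first be cut down to the latent frame count. A minimal sketch with hypothetical shapes (channel counts and ltnt_frames are made up for illustration):

# Minimal sketch, hypothetical shapes: the frame dimension (dim=2) of all
# inputs must match before concatenating along the channel dimension.
import torch

ltnt_frames = 21
hidden_states = torch.randn(1, 16, ltnt_frames, 32, 32)     # (bs, C, F, H, W)
prev_mask = torch.randn(1, 4, ltnt_frames, 32, 32)
prev_latents = torch.randn(1, 16, ltnt_frames + 3, 32, 32)  # longer than the latent

x = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
assert x.shape == (1, 36, ltnt_frames, 32, 32)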
lightx2v/models/runners/wan/wan_audio_runner.py
@@ -133,8 +133,8 @@ def adaptive_resize(img):
 def array_to_video(
     image_array: np.ndarray,
     output_path: str,
-    fps: Union[int, float] = 30,
-    resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+    fps: int | float = 30,
+    resolution: tuple[int, int] | tuple[float, float] | None = None,
     disable_log: bool = False,
     lossless: bool = True,
     output_pix_fmt: str = "yuv420p",
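Note: the rewritten fps and resolution annotations use PEP 604 union syntax (int | float) and the built-in generic tuple, both of which require Python 3.10 or newer; the replaced Union/Optional/Tuple spellings also work on older interpreters.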
@@ -221,6 +221,9 @@ def array_to_video(
         output_path,
     ]
+    if output_pix_fmt is not None:
+        command += ["-pix_fmt", output_pix_fmt]
+
     if not disable_log:
         print(f'Running "{" ".join(command)}"')
     process = subprocess.Popen(
@@ -283,7 +286,7 @@ def save_to_video(gen_lvideo, out_path, target_fps):
     gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
     gen_lvideo = gen_lvideo[..., ::-1].copy()
     generate_unique_path(out_path)
-    array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False)
+    array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False, output_pix_fmt="yuv444p")
 
 
 def save_audio(
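Taken together, the two hunks above thread a pixel format through to ffmpeg: array_to_video appends -pix_fmt when output_pix_fmt is set, and save_to_video now requests yuv444p, which keeps full-resolution chroma instead of yuv420p's 4:2:0 subsampling. A minimal sketch of the idea with a hypothetical base command (codec and input options elided; the flag is placed before the output path here, since ffmpeg applies output options to the file that follows them):

# Minimal sketch, hypothetical command layout; only the -pix_fmt guard
# mirrors the diff.
def build_command(output_path, fps=30, output_pix_fmt="yuv420p"):
    command = ["ffmpeg", "-y", "-r", str(fps), "-i", "-"]
    if output_pix_fmt is not None:  # same guard as in the diff
        command += ["-pix_fmt", output_pix_fmt]
    command.append(output_path)
    return command

print(" ".join(build_command("out.mp4", fps=16, output_pix_fmt="yuv444p")))
# -> ffmpeg -y -r 16 -i - -pix_fmt yuv444p out.mp4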
@@ -497,8 +500,9 @@ class WanAudioRunner(WanRunner):
         vae_dtype = torch.float
         for idx in range(interval_num):
-            torch.manual_seed(42 + idx)
-            logger.info(f"### manual_seed: {42 + idx} ####")
+            self.config.seed = self.config.seed + idx
+            torch.manual_seed(self.config.seed)
+            logger.info(f"### manual_seed: {self.config.seed} ####")
             useful_length = -1
             if idx == 0:  # first segment: condition padding 0
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
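This hunk swaps the hard-coded seed 42 + idx for the configured seed. Because self.config.seed is reassigned inside the loop, the per-segment seeds accumulate rather than stepping linearly; a minimal sketch with a hypothetical Config stand-in traces the values actually passed to torch.manual_seed:

# Minimal sketch, hypothetical Config object, same accumulation as the diff.
class Config:
    seed = 100

config = Config()
seeds = []
for idx in range(4):
    config.seed = config.seed + idx
    seeds.append(config.seed)

print(seeds)  # [100, 101, 103, 106] -- cumulative, not 100 + idx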
@@ -554,8 +558,9 @@ class WanAudioRunner(WanRunner):
             ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
             # bs = 1
             frames_n = (nframe - 1) * 4 + 1
-            prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype)
-            prev_mask[:, prev_len:] = 0
+            prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
+            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+            prev_mask[:, prev_frame_len:] = 0
             prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
             previmg_encoder_output = {
                 "prev_latents": prev_latents,
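This hunk is the r2v prev_frame mask fix named in the commit message. The old code allocated an all-zeros mask and then zeroed a suffix of it, a no-op that marked no frames as conditioning; it also sliced with prev_len, a latent-frame count, while the mask spans frames_n pixel frames. The new code builds the mask from ones, converts the latent length to pixel frames with (prev_len - 1) * 4 + 1 (matching the frames_n computation above), and zeroes everything after the carried-over frames. A minimal sketch with hypothetical sizes:

# Minimal sketch, hypothetical sizes: old no-op mask vs. the fixed one.
import torch

nframe, height, width = 21, 32, 32
prev_len = 6                         # latent frames carried over
frames_n = (nframe - 1) * 4 + 1      # latent -> pixel frame count

# old: zeros, then zeroing a suffix changes nothing -> no frames marked
old_mask = torch.zeros((1, frames_n, height, width))
old_mask[:, prev_len:] = 0
assert old_mask.sum() == 0

# new: ones over the carried-over pixel frames, zeros afterwards
prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
new_mask = torch.ones((1, frames_n, height, width))
new_mask[:, prev_frame_len:] = 0
assert new_mask[:, :prev_frame_len].all()
assert not new_mask[:, prev_frame_len:].any()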
lightx2v/models/schedulers/wan/audio/audio.py
deleted 100644 → 0 (entire file removed; contents as of parent 9b2c5b6b)
import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)


def get_timesteps(num_steps, max_steps: int = 1000):
    return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)


def timestep_shift(timesteps, shift: float = 1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)


class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_noise_sigma = 1.0

    def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
        dtype = x0.dtype
        device = x0.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        sigma = unsqueeze_to_ndim(sigma, x0.ndim)
        xt = x0.float() * (1 - sigma) + noise.float() * sigma
        return xt.to(dtype)

    def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
        return noise - x0

    def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
        device = v_loss.device
        sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
        return v_loss.float() * (sigma**2)


class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
    def __init__(self, config):
        self.config = config
        self.step_index = 0
        self.latents = None
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
        self.transformer_infer = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.noise_pred = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def set_shift(self, shift: float = 1.0):
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = timestep_shift(self.sigmas, shift=shift)
        self.timesteps = self.sigmas * self.num_train_timesteps

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
        self.timesteps_ori = self.timesteps.clone()
        self.set_shift(self.sample_shift)
        self._step_index = None
        self._begin_index = None

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if os.path.isfile(self.config.image_path):
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        else:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)
        if self.step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)  # pyright: ignore
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
        # x0 = sample - model_output * sigma
        x_t_next = sample + (sigma_next - sigma) * model_output
        self._step_index += 1
        return x_t_next

    def reset(self):
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()
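The deleted scheduler's step_post is a plain Euler step on the flow-matching ODE: with shifted sigmas from timestep_shift, the update is x_next = x + (sigma_next - sigma) * v, where v is the model's predicted velocity. A minimal standalone sketch of that math (shapes and shift value hypothetical):

# Minimal sketch of the deleted scheduler's core update.
import numpy as np
import torch

def timestep_shift(timesteps, shift=1.0):
    return shift * timesteps / (1 + (shift - 1) * timesteps)

num_train_timesteps = 1000
timesteps = np.linspace(num_train_timesteps, 0, 30 + 1, dtype=np.float32)
sigmas = timestep_shift(torch.from_numpy(timesteps) / num_train_timesteps, shift=5.0)

sample = torch.randn(16, 21, 32, 32)   # hypothetical latent
velocity = torch.randn_like(sample)    # hypothetical model output
i = 0
x_next = sample + (sigmas[i + 1] - sigmas[i]) * velocity  # one Euler step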
lightx2v/models/schedulers/wan/audio/scheduler.py
@@ -83,9 +83,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.sigmas = self.sigmas.to("cpu")
 
     def prepare(self, image_encoder_output=None):
-        self.generator = torch.Generator(device=self.device)
-        self.generator.manual_seed(self.config.seed)
-
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
         if self.config.task in ["t2v"]:
@@ -113,6 +110,7 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
 
     def prepare_latents(self, target_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
         self.latents = (
             torch.randn(
                 target_shape[0],
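Moving the seeding from prepare() into prepare_latents() means every latent (re)preparation reseeds its own generator, so repeated calls with the same config.seed draw identical noise. A minimal sketch of that behavior (hypothetical shape and seed):

# Minimal sketch: reseeding a Generator inside the preparation call
# makes repeated latent preparations reproducible.
import torch

def prepare_latents(target_shape, seed, device="cpu"):
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(*target_shape, generator=generator, device=device)

a = prepare_latents((16, 21, 32, 32), seed=42)
b = prepare_latents((16, 21, 32, 32), seed=42)
assert torch.equal(a, b)  # same seed -> identical latents on every call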