Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
d86b6917
Commit
d86b6917
authored
Jul 22, 2025
by
wangshankun
Browse files
update audio scheduler
parent
3dc1fafb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
62 deletions
+18
-62
lightx2v/models/schedulers/wan/audio/scheduler.py
lightx2v/models/schedulers/wan/audio/scheduler.py
+18
-62
No files found.
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
d86b6917
...
...
@@ -30,10 +30,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self
.
init_noise_sigma
=
1.0
self
.
config
=
config
self
.
latents
=
None
self
.
caching_records
=
[
True
]
*
config
.
infer_steps
self
.
flag_df
=
False
self
.
transformer_infer
=
None
self
.
solver_order
=
2
self
.
device
=
torch
.
device
(
"cuda"
)
self
.
infer_steps
=
self
.
config
.
infer_steps
self
.
target_video_length
=
self
.
config
.
target_video_length
...
...
@@ -41,47 +37,12 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self
.
shift
=
1
self
.
num_train_timesteps
=
1000
self
.
step_index
=
None
self
.
noise_pred
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
def step_pre(self, step_index):
    """Record the index of the upcoming denoising step.

    For BF16 inference runs, also casts the working latents to bfloat16 so the
    model consumes them in the expected dtype.
    """
    self.step_index = step_index
    wants_bf16 = GET_DTYPE() == "BF16"
    if wants_bf16:
        self.latents = self.latents.to(dtype=torch.bfloat16)
def set_timesteps(
    self,
    infer_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """Build the sigma schedule and integer timesteps for `infer_steps` steps.

    Populates `self.sigmas` (length ``infer_steps + 1``, float32, on CPU, ending
    in a terminal 0) and `self.timesteps` (length ``infer_steps``, int64, on
    `device`), and resets the multistep-solver bookkeeping.

    NOTE(review): the `sigmas` and `mu` parameters are accepted but never read
    in this implementation — presumably kept for interface compatibility.
    """
    # Evenly spaced sigmas from sigma_max down to sigma_min; the final linspace
    # endpoint is dropped so exactly `infer_steps` values remain.
    raw = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
    effective_shift = self.shift if shift is None else shift
    # Time-shift warp of the sigma grid.
    warped = effective_shift * raw / (1 + (effective_shift - 1) * raw)
    step_values = warped * self.num_train_timesteps
    # Append the terminal sigma (0) so sigmas has one more entry than timesteps.
    full_grid = np.concatenate([warped, [0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(full_grid)
    self.timesteps = torch.from_numpy(step_values).to(device=device, dtype=torch.int64)
    assert len(self.timesteps) == self.infer_steps
    # Fresh solver state for the new schedule.
    self.model_outputs = [None] * self.solver_order
    self.lower_order_nums = 0
    self.last_sample = None
    self._begin_index = None
    # Sigmas are kept on CPU; per-step lookups index them there.
    self.sigmas = self.sigmas.to("cpu")
def
prepare
(
self
,
image_encoder_output
=
None
):
self
.
prepare_latents
(
self
.
config
.
target_shape
,
dtype
=
torch
.
float32
)
...
...
@@ -90,24 +51,15 @@ class EulerSchedulerTimestepFix(BaseScheduler):
elif
self
.
config
.
task
in
[
"i2v"
]:
self
.
seq_len
=
((
self
.
config
.
target_video_length
-
1
)
//
self
.
config
.
vae_stride
[
0
]
+
1
)
*
self
.
config
.
lat_h
*
self
.
config
.
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
alphas
=
np
.
linspace
(
1
,
1
/
self
.
num_train_timesteps
,
self
.
num_train_timesteps
)[::
-
1
].
copy
()
sigmas
=
1.0
-
alphas
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
dtype
=
torch
.
float32
)
sigmas
=
self
.
shift
*
sigmas
/
(
1
+
(
self
.
shift
-
1
)
*
sigmas
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
self
.
sigmas
=
sigmas
self
.
timesteps
=
sigmas
*
self
.
num_train_
timesteps
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
device
)
self
.
timesteps
_ori
=
self
.
timesteps
.
clone
()
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
self
.
sigmas
=
self
.
sample_shift
*
self
.
sigmas
/
(
1
+
(
self
.
sample_shift
-
1
)
*
self
.
sigmas
)
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
device
,
shift
=
self
.
sample_shift
)
self
.
timesteps
=
self
.
sigmas
*
self
.
num_train_timesteps
def
prepare_latents
(
self
,
target_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
).
manual_seed
(
self
.
config
.
seed
)
...
...
@@ -128,21 +80,25 @@ class EulerSchedulerTimestepFix(BaseScheduler):
model_output
=
self
.
noise_pred
.
to
(
torch
.
float32
)
sample
=
self
.
latents
.
to
(
torch
.
float32
)
sample
=
sample
.
to
(
torch
.
float32
)
sigma
=
unsqueeze_to_ndim
(
self
.
sigmas
[
self
.
step_index
],
sample
.
ndim
).
to
(
sample
.
device
,
sample
.
dtype
)
sigma_next
=
unsqueeze_to_ndim
(
self
.
sigmas
[
self
.
step_index
+
1
],
sample
.
ndim
).
to
(
sample
.
device
,
sample
.
dtype
)
# x0 = sample - model_output * sigma
x_t_next
=
sample
+
(
sigma_next
-
sigma
)
*
model_output
self
.
latents
=
x_t_next
def reset(self):
    """Clear per-run solver state and re-draw fresh initial latents.

    NOTE(review): `model_outputs`/`timestep_list` are reset to length-1 lists
    here, while `set_timesteps` sizes them to `solver_order` — confirm this
    asymmetry is intentional.
    """
    self.model_outputs, self.timestep_list = [None], [None]
    self.last_sample = self.noise_pred = self.this_order = None
    self.lower_order_nums = 0
    # Re-seed the latents buffer for the next generation.
    self.prepare_latents(self.config.target_shape, dtype=torch.float32)
    # Release any now-unreferenced GPU memory.
    gc.collect()
    torch.cuda.empty_cache()
class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
    """Scheduler variant that steps by reconstructing x0 and re-noising.

    Unlike the parent Euler update, each step forms the denoised estimate
    ``x0 = sample - model_output * sigma`` and then blends it with fresh
    Gaussian noise at the next sigma level.
    """

    def step_post(self):
        """Advance `self.latents` by one consistency-model step."""
        logger.info(f"Step index: {self.step_index}, self.timestep: {self.timesteps[self.step_index]}")
        idx = self.step_index
        prediction = self.noise_pred.to(torch.float32)
        current = self.latents.to(torch.float32)
        # Broadcast the scalar sigmas up to the sample's rank and dtype/device.
        sigma_now = unsqueeze_to_ndim(self.sigmas[idx], current.ndim).to(current.device, current.dtype)
        sigma_nxt = unsqueeze_to_ndim(self.sigmas[idx + 1], current.ndim).to(current.device, current.dtype)
        # Denoised estimate at the current sigma.
        denoised = current - prediction * sigma_now
        # Re-noise toward the next sigma level with fresh Gaussian noise.
        self.latents = denoised * (1 - sigma_nxt) + sigma_nxt * torch.randn_like(denoised)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment