Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
6af19588
Commit
6af19588
authored
Jul 17, 2025
by
Zhuguanyu Wu
Committed by
GitHub
Jul 17, 2025
Browse files
Use CM scheduler for distill_models as default (#127)
parent
d594d5ac
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
33 deletions
+24
-33
lightx2v/models/schedulers/wan/step_distill/scheduler.py
lightx2v/models/schedulers/wan/step_distill/scheduler.py
+24
-33
No files found.
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
6af19588
...
...
@@ -12,6 +12,10 @@ class WanStepDistillScheduler(WanScheduler):
self
.
infer_steps
=
len
(
self
.
denoising_step_list
)
self
.
sample_shift
=
self
.
config
.
sample_shift
self
.
num_train_timesteps
=
1000
self
.
sigma_max
=
1.0
self
.
sigma_min
=
0.0
def
prepare
(
self
,
image_encoder_output
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
)
self
.
generator
.
manual_seed
(
self
.
config
.
seed
)
...
...
@@ -23,43 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
elif
self
.
config
.
task
in
[
"i2v"
]:
self
.
seq_len
=
self
.
config
.
lat_h
*
self
.
config
.
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
*
self
.
config
.
target_shape
[
1
]
alphas
=
np
.
linspace
(
1
,
1
/
self
.
num_train_timesteps
,
self
.
num_train_timesteps
)[::
-
1
].
copy
()
sigmas
=
1.0
-
alphas
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
dtype
=
torch
.
float32
)
sigmas
=
self
.
shift
*
sigmas
/
(
1
+
(
self
.
shift
-
1
)
*
sigmas
)
self
.
sigmas
=
sigmas
self
.
timesteps
=
sigmas
*
self
.
num_train_timesteps
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_denoising_timesteps
(
device
=
self
.
device
)
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
s
elf
.
timesteps
=
torch
.
tensor
(
self
.
denoising_step_list
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
sigmas
=
torch
.
cat
([
self
.
timesteps
/
self
.
num_train_timesteps
,
torch
.
tensor
([
0.0
],
device
=
device
)])
self
.
sigmas
=
self
.
s
igmas
.
to
(
"cpu"
)
self
.
infer_
steps
=
len
(
self
.
timesteps
)
s
igma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
self
.
sigmas
=
torch
.
linspace
(
sigma_start
,
self
.
sigma_min
,
self
.
num_train_timesteps
+
1
)[:
-
1
]
self
.
sigmas
=
self
.
s
ample_shift
*
self
.
sigmas
/
(
1
+
(
self
.
sample_shift
-
1
)
*
self
.
sigmas
)
self
.
time
steps
=
self
.
sigmas
*
self
.
num_train_
timesteps
self
.
model_outputs
=
[
None
,
]
*
self
.
solver_order
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
self
.
_begin_index
=
None
self
.
denoising_step_index
=
[
self
.
num_train_timesteps
-
x
for
x
in
self
.
denoising_step_list
]
self
.
timesteps
=
self
.
timesteps
[
self
.
denoising_step_index
].
to
(
device
)
self
.
sigmas
=
self
.
sigmas
[
self
.
denoising_step_index
].
to
(
"cpu"
)
def
reset
(
self
):
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
noise_pred
=
None
self
.
this_order
=
None
self
.
lower_order_nums
=
0
self
.
prepare_latents
(
self
.
config
.
target_shape
,
dtype
=
torch
.
float32
)
def
add_noise
(
self
,
original_samples
,
noise
,
sigma
):
sample
=
(
1
-
sigma
)
*
original_samples
+
sigma
*
noise
return
sample
.
type_as
(
noise
)
def
step_post
(
self
):
flow_pred
=
self
.
noise_pred
.
to
(
torch
.
float32
)
sigma
=
self
.
sigmas
[
self
.
step_index
].
item
()
noisy_image_or_video
=
self
.
latents
.
to
(
torch
.
float32
)
-
sigma
*
flow_pred
if
self
.
step_index
<
self
.
infer_steps
-
1
:
sigma
=
self
.
sigmas
[
self
.
step_index
+
1
].
item
()
noisy_image_or_video
=
self
.
add_noise
(
noisy_image_or_video
,
torch
.
randn_like
(
noisy_image_or_video
),
self
.
sigmas
[
self
.
step_index
+
1
].
item
())
self
.
latents
=
noisy_image_or_video
.
to
(
self
.
latents
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment