Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
6af19588
"vscode:/vscode.git/clone" did not exist on "f755fd260aa454d3d0526ecad782bcdc6440ec41"
Commit
6af19588
authored
Jul 17, 2025
by
Zhuguanyu Wu
Committed by
GitHub
Jul 17, 2025
Browse files
Use CM scheduler for distill_models as default (#127)
parent
d594d5ac
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
33 deletions
+24
-33
lightx2v/models/schedulers/wan/step_distill/scheduler.py
lightx2v/models/schedulers/wan/step_distill/scheduler.py
+24
-33
No files found.
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
6af19588
...
...
@@ -12,6 +12,10 @@ class WanStepDistillScheduler(WanScheduler):
self
.
infer_steps
=
len
(
self
.
denoising_step_list
)
self
.
sample_shift
=
self
.
config
.
sample_shift
self
.
num_train_timesteps
=
1000
self
.
sigma_max
=
1.0
self
.
sigma_min
=
0.0
def
prepare
(
self
,
image_encoder_output
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
)
self
.
generator
.
manual_seed
(
self
.
config
.
seed
)
...
...
@@ -23,43 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
elif
self
.
config
.
task
in
[
"i2v"
]:
self
.
seq_len
=
self
.
config
.
lat_h
*
self
.
config
.
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
*
self
.
config
.
target_shape
[
1
]
alphas
=
np
.
linspace
(
1
,
1
/
self
.
num_train_timesteps
,
self
.
num_train_timesteps
)[::
-
1
].
copy
()
sigmas
=
1.0
-
alphas
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
dtype
=
torch
.
float32
)
sigmas
=
self
.
shift
*
sigmas
/
(
1
+
(
self
.
shift
-
1
)
*
sigmas
)
self
.
sigmas
=
sigmas
self
.
timesteps
=
sigmas
*
self
.
num_train_timesteps
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_denoising_timesteps
(
device
=
self
.
device
)
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
s
elf
.
timesteps
=
torch
.
tensor
(
self
.
denoising_step_list
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
sigmas
=
torch
.
cat
([
self
.
timesteps
/
self
.
num_train_timesteps
,
torch
.
tensor
([
0.0
],
device
=
device
)])
self
.
sigmas
=
self
.
s
igmas
.
to
(
"cpu"
)
self
.
infer_
steps
=
len
(
self
.
timesteps
)
s
igma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
self
.
sigmas
=
torch
.
linspace
(
sigma_start
,
self
.
sigma_min
,
self
.
num_train_timesteps
+
1
)[:
-
1
]
self
.
sigmas
=
self
.
s
ample_shift
*
self
.
sigmas
/
(
1
+
(
self
.
sample_shift
-
1
)
*
self
.
sigmas
)
self
.
time
steps
=
self
.
sigmas
*
self
.
num_train_
timesteps
self
.
model_outputs
=
[
None
,
]
*
self
.
solver_order
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
self
.
_begin_index
=
None
self
.
denoising_step_index
=
[
self
.
num_train_timesteps
-
x
for
x
in
self
.
denoising_step_list
]
self
.
timesteps
=
self
.
timesteps
[
self
.
denoising_step_index
].
to
(
device
)
self
.
sigmas
=
self
.
sigmas
[
self
.
denoising_step_index
].
to
(
"cpu"
)
def
reset
(
self
):
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
noise_pred
=
None
self
.
this_order
=
None
self
.
lower_order_nums
=
0
self
.
prepare_latents
(
self
.
config
.
target_shape
,
dtype
=
torch
.
float32
)
def
add_noise
(
self
,
original_samples
,
noise
,
sigma
):
sample
=
(
1
-
sigma
)
*
original_samples
+
sigma
*
noise
return
sample
.
type_as
(
noise
)
def
step_post
(
self
):
flow_pred
=
self
.
noise_pred
.
to
(
torch
.
float32
)
sigma
=
self
.
sigmas
[
self
.
step_index
].
item
()
noisy_image_or_video
=
self
.
latents
.
to
(
torch
.
float32
)
-
sigma
*
flow_pred
if
self
.
step_index
<
self
.
infer_steps
-
1
:
sigma
=
self
.
sigmas
[
self
.
step_index
+
1
].
item
()
noisy_image_or_video
=
self
.
add_noise
(
noisy_image_or_video
,
torch
.
randn_like
(
noisy_image_or_video
),
self
.
sigmas
[
self
.
step_index
+
1
].
item
())
self
.
latents
=
noisy_image_or_video
.
to
(
self
.
latents
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment