Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
074d281a
Commit
074d281a
authored
Apr 09, 2023
by
William Berman
Committed by
Will Berman
Apr 10, 2023
Browse files
tests and additional scheduler fixes
parent
953c9d14
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
3 deletions
+36
-3
src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_deis_multistep.py
+10
-1
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+10
-2
tests/schedulers/test_scheduler_dpm_multi.py
tests/schedulers/test_scheduler_dpm_multi.py
+8
-0
tests/schedulers/test_scheduler_unipc.py
tests/schedulers/test_scheduler_unipc.py
+8
-0
No files found.
src/diffusers/schedulers/scheduling_deis_multistep.py
View file @
074d281a
...
...
@@ -171,6 +171,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
...
...
@@ -181,14 +182,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self
.
num_inference_steps
=
num_inference_steps
timesteps
=
(
np
.
linspace
(
0
,
self
.
num_train_timesteps
-
1
,
num_inference_steps
+
1
)
.
round
()[::
-
1
][:
-
1
]
.
copy
()
.
astype
(
np
.
int64
)
)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_
,
unique_indices
=
np
.
unique
(
timesteps
,
return_index
=
True
)
timesteps
=
timesteps
[
np
.
sort
(
unique_indices
)]
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
)
self
.
num_inference_steps
=
len
(
timesteps
)
self
.
model_outputs
=
[
None
,
]
*
self
.
config
.
solver_order
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
074d281a
...
...
@@ -194,21 +194,29 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self
.
num_inference_steps
=
num_inference_steps
timesteps
=
(
np
.
linspace
(
0
,
self
.
num_train_timesteps
-
1
,
num_inference_steps
+
1
)
.
round
()[::
-
1
][:
-
1
]
.
copy
()
.
astype
(
np
.
int64
)
)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_
,
unique_indices
=
np
.
unique
(
timesteps
,
return_index
=
True
)
timesteps
=
timesteps
[
np
.
sort
(
unique_indices
)]
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
)
self
.
num_inference_steps
=
len
(
timesteps
)
self
.
model_outputs
=
[
None
,
]
*
self
.
config
.
solver_order
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
if
self
.
solver_p
:
self
.
solver_p
.
set_timesteps
(
num_inference_steps
,
device
=
device
)
self
.
solver_p
.
set_timesteps
(
self
.
num_inference_steps
,
device
=
device
)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
tests/schedulers/test_scheduler_dpm_multi.py
View file @
074d281a
...
...
@@ -243,3 +243,11 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
sample
=
scheduler
.
step
(
residual
,
t
,
sample
).
prev_sample
assert
sample
.
dtype
==
torch
.
float16
def
test_unique_timesteps
(
self
,
**
config
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
scheduler
.
config
.
num_train_timesteps
)
assert
len
(
scheduler
.
timesteps
.
unique
())
==
scheduler
.
num_inference_steps
tests/schedulers/test_scheduler_unipc.py
View file @
074d281a
...
...
@@ -229,3 +229,11 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample
=
scheduler
.
step
(
residual
,
t
,
sample
).
prev_sample
assert
sample
.
dtype
==
torch
.
float16
def
test_unique_timesteps
(
self
,
**
config
):
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
scheduler
.
config
.
num_train_timesteps
)
assert
len
(
scheduler
.
timesteps
.
unique
())
==
scheduler
.
num_inference_steps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment