Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3f1861ee
Unverified
Commit
3f1861ee
authored
Aug 21, 2022
by
Nathan Lambert
Committed by
GitHub
Aug 22, 2022
Browse files
hotfix for pdnm test (#220)
parent
6a03060c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
9 deletions
+12
-9
tests/test_scheduler.py
tests/test_scheduler.py
+12
-9
No files found.
tests/test_scheduler.py
View file @
3f1861ee
...
...
@@ -426,16 +426,18 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residuals
# copy over dummy past residuals
(must be after setting timesteps)
scheduler
.
ets
=
dummy_past_residuals
[:]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_timesteps
(
num_inference_steps
)
# copy over dummy past residual (must be after setting timesteps)
new_scheduler
.
ets
=
dummy_past_residuals
[:]
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
...
...
@@ -461,12 +463,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
tensor_format
=
"np"
,
**
scheduler_config
)
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
# copy over dummy past residuals
scheduler_pt
.
ets
=
dummy_past_residuals_pt
[:]
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
...
...
@@ -474,6 +472,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler_pt
.
ets
=
dummy_past_residuals_pt
[:]
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
...
@@ -494,15 +496,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
# copy over dummy past residuals
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
scheduler
.
ets
=
dummy_past_residuals
[:]
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
scheduler
.
ets
=
dummy_past_residuals
[:]
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment