Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1997b908
Commit
1997b908
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
image->sample in schedule tests
parent
b2274ece
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
67 deletions
+67
-67
tests/test_scheduler.py
tests/test_scheduler.py
+67
-67
No files found.
tests/test_scheduler.py
View file @
1997b908
...
@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase):
forward_default_kwargs
=
()
forward_default_kwargs
=
()
@
property
@
property
def
dummy_
imag
e
(
self
):
def
dummy_
sampl
e
(
self
):
batch_size
=
4
batch_size
=
4
num_channels
=
3
num_channels
=
3
height
=
8
height
=
8
width
=
8
width
=
8
imag
e
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
sampl
e
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
return
imag
e
return
sampl
e
@
property
@
property
def
dummy_
imag
e_deter
(
self
):
def
dummy_
sampl
e_deter
(
self
):
batch_size
=
4
batch_size
=
4
num_channels
=
3
num_channels
=
3
height
=
8
height
=
8
width
=
8
width
=
8
num_elems
=
batch_size
*
num_channels
*
height
*
width
num_elems
=
batch_size
*
num_channels
*
height
*
width
imag
e
=
np
.
arange
(
num_elems
)
sampl
e
=
np
.
arange
(
num_elems
)
image
=
imag
e
.
reshape
(
num_channels
,
height
,
width
,
batch_size
)
sample
=
sampl
e
.
reshape
(
num_channels
,
height
,
width
,
batch_size
)
image
=
imag
e
/
num_elems
sample
=
sampl
e
/
num_elems
image
=
imag
e
.
transpose
(
3
,
0
,
1
,
2
)
sample
=
sampl
e
.
transpose
(
3
,
0
,
1
,
2
)
return
imag
e
return
sampl
e
def
get_scheduler_config
(
self
):
def
get_scheduler_config
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
dummy_model
(
self
):
def
dummy_model
(
self
):
def
model
(
imag
e
,
t
,
*
args
):
def
model
(
sampl
e
,
t
,
*
args
):
return
imag
e
*
t
/
(
t
+
1
)
return
sampl
e
*
t
/
(
t
+
1
)
return
model
return
model
...
@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase):
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
@@ -90,8 +90,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -90,8 +90,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
.
update
(
forward_kwargs
)
kwargs
.
update
(
forward_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
...
@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
imag
e
,
1
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
imag
e
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sampl
e
,
1
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
@@ -132,32 +132,32 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -132,32 +132,32 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
output_0
=
scheduler
.
step
(
residual
,
imag
e
,
0
,
**
kwargs
)
output_0
=
scheduler
.
step
(
residual
,
sampl
e
,
0
,
**
kwargs
)
output_1
=
scheduler
.
step
(
residual
,
imag
e
,
1
,
**
kwargs
)
output_1
=
scheduler
.
step
(
residual
,
sampl
e
,
1
,
**
kwargs
)
self
.
assertEqual
(
output_0
.
shape
,
imag
e
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sampl
e
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
def
test_pytorch_equal_numpy
(
self
):
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
imag
e_pt
=
torch
.
tensor
(
imag
e
)
sampl
e_pt
=
torch
.
tensor
(
sampl
e
)
residual_pt
=
0.1
*
imag
e_pt
residual_pt
=
0.1
*
sampl
e_pt
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
output
=
scheduler
.
step
(
residual
,
imag
e
,
1
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
imag
e_pt
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
sampl
e_pt
,
1
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for
variance
in
[
"fixed_small"
,
"fixed_large"
,
"other"
]:
for
variance
in
[
"fixed_small"
,
"fixed_large"
,
"other"
]:
self
.
check_over_configs
(
variance_type
=
variance
)
self
.
check_over_configs
(
variance_type
=
variance
)
def
test_clip_
imag
e
(
self
):
def
test_clip_
sampl
e
(
self
):
for
clip_sample
in
[
True
,
False
]:
for
clip_sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
...
@@ -219,23 +219,23 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -219,23 +219,23 @@ class DDPMSchedulerTest(SchedulerCommonTest):
num_trained_timesteps
=
len
(
scheduler
)
num_trained_timesteps
=
len
(
scheduler
)
model
=
self
.
dummy_model
()
model
=
self
.
dummy_model
()
imag
e
=
self
.
dummy_
imag
e_deter
sampl
e
=
self
.
dummy_
sampl
e_deter
for
t
in
reversed
(
range
(
num_trained_timesteps
)):
for
t
in
reversed
(
range
(
num_trained_timesteps
)):
# 1. predict noise residual
# 1. predict noise residual
residual
=
model
(
imag
e
,
t
)
residual
=
model
(
sampl
e
,
t
)
# 2. predict previous mean of
imag
e x_t-1
# 2. predict previous mean of
sampl
e x_t-1
pred_prev_
imag
e
=
scheduler
.
step
(
residual
,
imag
e
,
t
)
pred_prev_
sampl
e
=
scheduler
.
step
(
residual
,
sampl
e
,
t
)
if
t
>
0
:
if
t
>
0
:
noise
=
self
.
dummy_
imag
e_deter
noise
=
self
.
dummy_
sampl
e_deter
variance
=
scheduler
.
get_variance
(
t
)
**
(
0.5
)
*
noise
variance
=
scheduler
.
get_variance
(
t
)
**
(
0.5
)
*
noise
imag
e
=
pred_prev_
imag
e
+
variance
sampl
e
=
pred_prev_
sampl
e
+
variance
result_sum
=
np
.
sum
(
np
.
abs
(
imag
e
))
result_sum
=
np
.
sum
(
np
.
abs
(
sampl
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
imag
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
sampl
e
))
assert
abs
(
result_sum
.
item
()
-
732.9947
)
<
1e-2
assert
abs
(
result_sum
.
item
()
-
732.9947
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.9544
)
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.9544
)
<
1e-3
...
@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_clip_
imag
e
(
self
):
def
test_clip_
sampl
e
(
self
):
for
clip_sample
in
[
True
,
False
]:
for
clip_sample
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
self
.
check_over_configs
(
clip_sample
=
clip_sample
)
...
@@ -308,22 +308,22 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -308,22 +308,22 @@ class DDIMSchedulerTest(SchedulerCommonTest):
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
model
=
self
.
dummy_model
()
model
=
self
.
dummy_model
()
imag
e
=
self
.
dummy_
imag
e_deter
sampl
e
=
self
.
dummy_
sampl
e_deter
for
t
in
reversed
(
range
(
num_inference_steps
)):
for
t
in
reversed
(
range
(
num_inference_steps
)):
residual
=
model
(
imag
e
,
inference_step_times
[
t
])
residual
=
model
(
sampl
e
,
inference_step_times
[
t
])
pred_prev_
imag
e
=
scheduler
.
step
(
residual
,
imag
e
,
t
,
num_inference_steps
,
eta
)
pred_prev_
sampl
e
=
scheduler
.
step
(
residual
,
sampl
e
,
t
,
num_inference_steps
,
eta
)
variance
=
0
variance
=
0
if
eta
>
0
:
if
eta
>
0
:
noise
=
self
.
dummy_
imag
e_deter
noise
=
self
.
dummy_
sampl
e_deter
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
**
(
0.5
)
*
eta
*
noise
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
**
(
0.5
)
*
eta
*
noise
imag
e
=
pred_prev_
imag
e
+
variance
sampl
e
=
pred_prev_
sampl
e
+
variance
result_sum
=
np
.
sum
(
np
.
abs
(
imag
e
))
result_sum
=
np
.
sum
(
np
.
abs
(
sampl
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
imag
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
sampl
e
))
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.3524
)
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.3524
)
<
1e-3
...
@@ -346,8 +346,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -346,8 +346,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def
check_over_configs_pmls
(
self
,
time_step
=
0
,
**
config
):
def
check_over_configs_pmls
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
...
@@ -365,16 +365,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -365,16 +365,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
check_over_forward_pmls
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward_pmls
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
kwargs
.
update
(
forward_kwargs
)
imag
e
=
self
.
dummy_
imag
e
sampl
e
=
self
.
dummy_
sampl
e
residual
=
0.1
*
imag
e
residual
=
0.1
*
sampl
e
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
...
@@ -392,8 +392,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -392,8 +392,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
imag
e
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sampl
e
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
@@ -445,7 +445,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -445,7 +445,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler
.
set_plms_mode
()
scheduler
.
set_plms_mode
()
scheduler
.
step
(
self
.
dummy_
imag
e
,
self
.
dummy_
imag
e
,
1
,
50
)
scheduler
.
step
(
self
.
dummy_
sampl
e
,
self
.
dummy_
sampl
e
,
1
,
50
)
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -454,24 +454,24 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -454,24 +454,24 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps
=
10
num_inference_steps
=
10
model
=
self
.
dummy_model
()
model
=
self
.
dummy_model
()
imag
e
=
self
.
dummy_
imag
e_deter
sampl
e
=
self
.
dummy_
sampl
e_deter
prk_time_steps
=
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
prk_time_steps
=
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
prk_time_steps
)):
for
t
in
range
(
len
(
prk_time_steps
)):
t_orig
=
prk_time_steps
[
t
]
t_orig
=
prk_time_steps
[
t
]
residual
=
model
(
imag
e
,
t_orig
)
residual
=
model
(
sampl
e
,
t_orig
)
imag
e
=
scheduler
.
step_prk
(
residual
,
imag
e
,
t
,
num_inference_steps
)
sampl
e
=
scheduler
.
step_prk
(
residual
,
sampl
e
,
t
,
num_inference_steps
)
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
timesteps
)):
for
t
in
range
(
len
(
timesteps
)):
t_orig
=
timesteps
[
t
]
t_orig
=
timesteps
[
t
]
residual
=
model
(
imag
e
,
t_orig
)
residual
=
model
(
sampl
e
,
t_orig
)
imag
e
=
scheduler
.
step_plms
(
residual
,
imag
e
,
t
,
num_inference_steps
)
sampl
e
=
scheduler
.
step_plms
(
residual
,
sampl
e
,
t
,
num_inference_steps
)
result_sum
=
np
.
sum
(
np
.
abs
(
imag
e
))
result_sum
=
np
.
sum
(
np
.
abs
(
sampl
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
imag
e
))
result_mean
=
np
.
mean
(
np
.
abs
(
sampl
e
))
assert
abs
(
result_sum
.
item
()
-
199.1169
)
<
1e-2
assert
abs
(
result_sum
.
item
()
-
199.1169
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.2593
)
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.2593
)
<
1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment