Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
b2274ece
Commit
b2274ece
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
finish pndm scheduler
parent
de22d4cd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
233 additions
and
41 deletions
+233
-41
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+3
-3
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+75
-31
tests/test_scheduler.py
tests/test_scheduler.py
+155
-7
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
b2274ece
...
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
...
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
)
)
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
warmup
_time_steps
=
self
.
noise_scheduler
.
get_
warmup
_time_steps
(
num_inference_steps
)
prk
_time_steps
=
self
.
noise_scheduler
.
get_
prk
_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
warmup
_time_steps
))):
for
t
in
tqdm
.
tqdm
(
range
(
len
(
prk
_time_steps
))):
t_orig
=
warmup
_time_steps
[
t
]
t_orig
=
prk
_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
residual
=
self
.
unet
(
image
,
t_orig
)
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
b2274ece
...
@@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# For now we only support F-PNDM, i.e. the runge-kutta method
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at
equations
(12)
and
(13) and the Algorithm 2.
# mainly at
formula (9),
(12)
,
(13) and the Algorithm 2.
self
.
pndm_order
=
4
self
.
pndm_order
=
4
# running values
# running values
self
.
cur_residual
=
0
self
.
cur_residual
=
0
self
.
cur_sample
=
None
self
.
cur_sample
=
None
self
.
ets
=
[]
self
.
ets
=
[]
self
.
warmup
_time_steps
=
{}
self
.
prk
_time_steps
=
{}
self
.
time_steps
=
{}
self
.
time_steps
=
{}
self
.
set_prk_mode
()
def
get_alpha
(
self
,
time_step
):
def
get_alpha
(
self
,
time_step
):
return
self
.
alphas
[
time_step
]
return
self
.
alphas
[
time_step
]
...
@@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
one
return
self
.
one
return
self
.
alphas_cumprod
[
time_step
]
return
self
.
alphas_cumprod
[
time_step
]
def
get_
warmup
_time_steps
(
self
,
num_inference_steps
):
def
get_
prk
_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
warmup
_time_steps
:
if
num_inference_steps
in
self
.
prk
_time_steps
:
return
self
.
warmup
_time_steps
[
num_inference_steps
]
return
self
.
prk
_time_steps
[
num_inference_steps
]
inference_step_times
=
list
(
range
(
0
,
self
.
config
.
timesteps
,
self
.
config
.
timesteps
//
num_inference_steps
))
inference_step_times
=
list
(
range
(
0
,
self
.
config
.
timesteps
,
self
.
config
.
timesteps
//
num_inference_steps
))
warmup
_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
prk
_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
np
.
array
([
0
,
self
.
config
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
)
self
.
warmup
_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup
_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
prk
_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
prk
_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
return
self
.
warmup
_time_steps
[
num_inference_steps
]
return
self
.
prk
_time_steps
[
num_inference_steps
]
def
get_time_steps
(
self
,
num_inference_steps
):
def
get_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
time_steps
:
if
num_inference_steps
in
self
.
time_steps
:
...
@@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
time_steps
[
num_inference_steps
]
return
self
.
time_steps
[
num_inference_steps
]
def
set_prk_mode
(
self
):
self
.
mode
=
"prk"
def
set_plms_mode
(
self
):
self
.
mode
=
"plms"
def
step
(
self
,
*
args
,
**
kwargs
):
if
self
.
mode
==
"prk"
:
return
self
.
step_prk
(
*
args
,
**
kwargs
)
if
self
.
mode
==
"plms"
:
return
self
.
step_plms
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"mode
{
self
.
mode
}
does not exist."
)
def
step_prk
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_prk
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
prk_time_steps
=
self
.
get_prk_time_steps
(
num_inference_steps
)
warmup_time_steps
=
self
.
get_warmup_time_steps
(
num_inference_steps
)
t_
prev
=
warmup
_time_steps
[
t
//
4
*
4
]
t_
orig
=
prk
_time_steps
[
t
//
4
*
4
]
t_
next
=
warmup
_time_steps
[
min
(
t
+
1
,
len
(
warmup
_time_steps
)
-
1
)]
t_
orig_prev
=
prk
_time_steps
[
min
(
t
+
1
,
len
(
prk
_time_steps
)
-
1
)]
if
t
%
4
==
0
:
if
t
%
4
==
0
:
self
.
cur_residual
+=
1
/
6
*
residual
self
.
cur_residual
+=
1
/
6
*
residual
...
@@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual
=
self
.
cur_residual
+
1
/
6
*
residual
residual
=
self
.
cur_residual
+
1
/
6
*
residual
self
.
cur_residual
=
0
self
.
cur_residual
=
0
return
self
.
transfer
(
self
.
cur_sample
,
t_prev
,
t_next
,
residual
)
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
def
step_plms
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_plms
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
t_
prev
=
timesteps
[
t
]
t_
orig
=
timesteps
[
t
]
t_
next
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
t_
orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
residual
)
self
.
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
transfer
(
sample
,
t_prev
,
t_next
,
residual
)
return
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
def
get_prev_sample
(
self
,
sample
,
t_orig
,
t_orig_prev
,
residual
):
# TODO(Patrick): clean up to be compatible with numpy and give better names
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
alphas_cump
=
self
.
alphas_cumprod
.
to
(
x
.
device
)
# Note that x_t needs to be added to both sides of the equation
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
# Notation (<variable name> -> <name in paper>
# alpha_prod_t -> α_t
x_delta
=
(
at_next
-
at
)
*
(
# alpha_prod_t_prev -> α_(t−δ)
(
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
# beta_prod_t -> (1 - α_t)
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
# beta_prod_t_prev -> (1 - α_(t−δ))
)
# sample -> x_t
# residual -> e_θ(x_t, t)
x_next
=
x
+
x_delta
# prev_sample -> x_(t−δ)
return
x_next
alpha_prod_t
=
self
.
get_alpha_prod
(
t_orig
+
1
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
t_orig_prev
+
1
)
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
# corresponds to (α_(t−δ) - α_t) divided by
# denominator of x_t in formula (9) and plus 1
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
# sqrt(α_(t−δ)) / sqrt(α_t))
sample_coeff
=
(
alpha_prod_t_prev
/
alpha_prod_t
)
**
(
0.5
)
# corresponds to denominator of e_θ(x_t, t) in formula (9)
residual_denom_coeff
=
alpha_prod_t
*
beta_prod_t_prev
**
(
0.5
)
+
(
alpha_prod_t
*
beta_prod_t
*
alpha_prod_t_prev
)
**
(
0.5
)
# full formula (9)
prev_sample
=
sample_coeff
*
sample
-
(
alpha_prod_t_prev
-
alpha_prod_t
)
*
residual
/
residual_denom_coeff
return
prev_sample
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
timesteps
return
self
.
config
.
timesteps
tests/test_scheduler.py
View file @
b2274ece
...
@@ -20,7 +20,7 @@ import unittest
...
@@ -20,7 +20,7 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
diffusers
import
DDIMScheduler
,
DDPMScheduler
from
diffusers
import
DDIMScheduler
,
DDPMScheduler
,
PNDMScheduler
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
...
@@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
.
update
(
forward_kwargs
)
kwargs
.
update
(
forward_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
image
=
self
.
dummy_image
image
=
self
.
dummy_image
residual
=
0.1
*
image
residual
=
0.1
*
image
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
@@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase):
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
image_pt
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
image_pt
,
1
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-
5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-
4
,
"Scheduler outputs are not identical"
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
...
@@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
assert
result_sum
.
item
()
-
732.9947
<
1e-
3
assert
abs
(
result_sum
.
item
()
-
732.9947
)
<
1e-
2
assert
result_mean
.
item
()
-
0.9544
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.9544
)
<
1e-3
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
...
@@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest):
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
assert
result_sum
.
item
()
-
270.6214
<
1e-3
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
result_mean
.
item
()
-
0.3524
<
1e-3
assert
abs
(
result_mean
.
item
()
-
0.3524
)
<
1e-3
class
PNDMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
PNDMScheduler
,)
forward_default_kwargs
=
((
"num_inference_steps"
,
50
),)
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
"timesteps"
:
1000
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
}
config
.
update
(
**
kwargs
)
return
config
def
check_over_configs_pmls
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
image
=
self
.
dummy_image
residual
=
0.1
*
image
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
set_plms_mode
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
check_over_forward_pmls
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
image
=
self
.
dummy_image
residual
=
0.1
*
image
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
set_plms_mode
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_timesteps
(
self
):
for
timesteps
in
[
100
,
1000
]:
self
.
check_over_configs
(
timesteps
=
timesteps
)
def
test_timesteps_pmls
(
self
):
for
timesteps
in
[
100
,
1000
]:
self
.
check_over_configs_pmls
(
timesteps
=
timesteps
)
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
],
[
0.002
,
0.02
,
0.2
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_betas_pmls
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
],
[
0.002
,
0.02
,
0.2
]):
self
.
check_over_configs_pmls
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_schedules_pmls
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_time_indices
(
self
):
for
t
in
[
1
,
5
,
10
]:
self
.
check_over_forward
(
time_step
=
t
)
def
test_time_indices_pmls
(
self
):
for
t
in
[
1
,
5
,
10
]:
self
.
check_over_forward_pmls
(
time_step
=
t
)
def
test_inference_steps
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
5
,
10
],
[
10
,
50
,
100
]):
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
def
test_inference_steps_pmls
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
5
,
10
],
[
10
,
50
,
100
]):
self
.
check_over_forward_pmls
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
def
test_inference_pmls_no_past_residuals
(
self
):
with
self
.
assertRaises
(
ValueError
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_plms_mode
()
scheduler
.
step
(
self
.
dummy_image
,
self
.
dummy_image
,
1
,
50
)
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
num_inference_steps
=
10
model
=
self
.
dummy_model
()
image
=
self
.
dummy_image_deter
prk_time_steps
=
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
prk_time_steps
)):
t_orig
=
prk_time_steps
[
t
]
residual
=
model
(
image
,
t_orig
)
image
=
scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
timesteps
)):
t_orig
=
timesteps
[
t
]
residual
=
model
(
image
,
t_orig
)
image
=
scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
assert
abs
(
result_sum
.
item
()
-
199.1169
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.2593
)
<
1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment