Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
81700555
Unverified
Commit
81700555
authored
Feb 09, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 09, 2023
Browse files
Test some flaky detection models on float64 instead of float32 (#7204)
parent
d75a5241
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
test/test_models.py
test/test_models.py
+14
-5
No files found.
test/test_models.py
View file @
81700555
...
...
@@ -29,7 +29,7 @@ def list_model_fns(module):
return
[
get_model_builder
(
name
)
for
name
in
list_models
(
module
)]
def
_get_image
(
input_shape
,
real_image
,
device
):
def
_get_image
(
input_shape
,
real_image
,
device
,
dtype
=
None
):
"""This routine loads a real or random image based on `real_image` argument.
Currently, the real image is utilized for the following list of models:
- `retinanet_resnet50_fpn`,
...
...
@@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
convert_tensor
=
transforms
.
ToTensor
()
image
=
convert_tensor
(
img
)
assert
tuple
(
image
.
size
())
==
input_shape
return
image
.
to
(
device
=
device
)
return
image
.
to
(
device
=
device
,
dtype
=
dtype
)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
return
torch
.
rand
(
input_shape
).
to
(
device
=
device
)
return
torch
.
rand
(
input_shape
).
to
(
device
=
device
,
dtype
=
dtype
)
@
pytest
.
fixture
...
...
@@ -278,6 +278,11 @@ autocast_flaky_numerics = (
# tests under test_quantized_classification_model will be skipped for the following models.
quantized_flaky_models
=
(
"inception_v3"
,
"resnet50"
)
# The tests for the following detection models are flaky.
# We run those tests on float64 to avoid floating point errors.
# FIXME: we shouldn't have to do that :'/
detection_flaky_models
=
(
"keypointrcnn_resnet50_fpn"
,
"maskrcnn_resnet50_fpn"
,
"maskrcnn_resnet50_fpn_v2"
)
# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
...
...
@@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
"input_shape"
:
(
3
,
300
,
300
),
}
model_name
=
model_fn
.
__name__
if
model_name
in
detection_flaky_models
:
dtype
=
torch
.
float64
else
:
dtype
=
torch
.
get_default_dtype
()
kwargs
=
{
**
defaults
,
**
_model_params
.
get
(
model_name
,
{})}
input_shape
=
kwargs
.
pop
(
"input_shape"
)
real_image
=
kwargs
.
pop
(
"real_image"
,
False
)
model
=
model_fn
(
**
kwargs
)
model
.
eval
().
to
(
device
=
dev
)
x
=
_get_image
(
input_shape
=
input_shape
,
real_image
=
real_image
,
device
=
dev
)
model
.
eval
().
to
(
device
=
dev
,
dtype
=
dtype
)
x
=
_get_image
(
input_shape
=
input_shape
,
real_image
=
real_image
,
device
=
dev
,
dtype
=
dtype
)
model_input
=
[
x
]
with
torch
.
no_grad
(),
freeze_rng_state
():
out
=
model
(
model_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment