Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
21049e90
Unverified
Commit
21049e90
authored
May 24, 2021
by
Nicolas Hug
Committed by
GitHub
May 24, 2021
Browse files
Use torch.testing.assert_close in test_models.py (#3879)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
b96d381c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
5 deletions
+4
-5
test/test_models.py
test/test_models.py
+4
-5
No files found.
test/test_models.py
View file @
21049e90
...
@@ -120,7 +120,7 @@ class ModelTester(TestCase):
...
@@ -120,7 +120,7 @@ class ModelTester(TestCase):
# predictions match.
# predictions match.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
self
.
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
expected
=
torch
.
load
(
expected_file
)
self
.
assert
Equal
(
out
.
argmax
(
dim
=
1
),
expected
.
argmax
(
dim
=
1
),
prec
=
prec
)
torch
.
testing
.
assert
_close
(
out
.
argmax
(
dim
=
1
),
expected
.
argmax
(
dim
=
1
),
rtol
=
prec
,
atol
=
prec
)
return
False
# Partial validation performed
return
False
# Partial validation performed
return
True
# Full validation performed
return
True
# Full validation performed
...
@@ -205,7 +205,8 @@ class ModelTester(TestCase):
...
@@ -205,7 +205,8 @@ class ModelTester(TestCase):
# scores.
# scores.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
self
.
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
expected
=
torch
.
load
(
expected_file
)
self
.
assertEqual
(
output
[
0
][
"scores"
],
expected
[
0
][
"scores"
],
prec
=
prec
)
torch
.
testing
.
assert_close
(
output
[
0
][
"scores"
],
expected
[
0
][
"scores"
],
rtol
=
prec
,
atol
=
prec
,
check_device
=
False
,
check_dtype
=
False
)
# Note: Fmassa proposed turning off NMS by adapting the threshold
# Note: Fmassa proposed turning off NMS by adapting the threshold
# and then using the Hungarian algorithm as in DETR to find the
# and then using the Hungarian algorithm as in DETR to find the
...
@@ -301,10 +302,8 @@ class ModelTester(TestCase):
...
@@ -301,10 +302,8 @@ class ModelTester(TestCase):
model2
.
eval
()
model2
.
eval
()
out2
=
model2
(
x
)
out2
=
model2
(
x
)
max_diff
=
(
out1
-
out2
).
abs
().
max
()
self
.
assertTrue
(
num_params
==
num_grad
)
self
.
assertTrue
(
num_params
==
num_grad
)
self
.
assertTrue
(
max_diff
<
1e-5
)
torch
.
testing
.
assert_close
(
out1
,
out2
,
rtol
=
0.0
,
atol
=
1e-5
)
def
test_resnet_dilation
(
self
):
def
test_resnet_dilation
(
self
):
# TODO improve tests to also check that each layer has the right dimensionality
# TODO improve tests to also check that each layer has the right dimensionality
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment