Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7daa90ac
"clients/vscode:/vscode.git/clone" did not exist on "55bd4fed7da83a566dca08b0bb29dbc5929a90eb"
Unverified
Commit
7daa90ac
authored
Jun 03, 2021
by
Nicolas Hug
Committed by
GitHub
Jun 03, 2021
Browse files
Take assertExpected and check_jit_scriptable out of the TestCase class (#3947)
parent
4c0fdc61
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
121 deletions
+116
-121
test/common_utils.py
test/common_utils.py
+0
-108
test/test_models.py
test/test_models.py
+116
-13
No files found.
test/common_utils.py
View file @
7daa90ac
...
...
@@ -5,9 +5,7 @@ import contextlib
import
unittest
import
argparse
import
sys
import
io
import
torch
import
warnings
import
__main__
import
random
import
inspect
...
...
@@ -15,7 +13,6 @@ import inspect
from
numbers
import
Number
from
torch._six
import
string_classes
from
collections
import
OrderedDict
from
_utils_internal
import
get_relative_path
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -49,10 +46,6 @@ def set_rng_seed(seed):
np
.
random
.
seed
(
seed
)
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
TEST_WITH_SLOW
=
os
.
getenv
(
'PYTORCH_TEST_WITH_SLOW'
,
'0'
)
==
'1'
class
MapNestedTensorObjectImpl
(
object
):
def
__init__
(
self
,
tensor_map_fn
):
self
.
tensor_map_fn
=
tensor_map_fn
...
...
@@ -95,55 +88,6 @@ def is_iterable(obj):
class
TestCase
(
unittest
.
TestCase
):
precision
=
1e-5
def
_get_expected_file
(
self
,
name
=
None
):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id
=
self
.
__class__
.
__module__
# Determine expected file based on environment
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
sys
.
modules
[
module_id
].
__file__
),
"expect"
)
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file
=
expected_file
=
os
.
path
.
join
(
expected_file_base
,
'ModelTester.test_'
+
name
)
expected_file
+=
"_expect.pkl"
if
not
ACCEPT
and
not
os
.
path
.
exists
(
expected_file
):
raise
RuntimeError
(
f
"No expect file exists for
{
os
.
path
.
basename
(
expected_file
)
}
in
{
expected_file
}
; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return
expected_file
def
assertExpected
(
self
,
output
,
name
,
prec
=
None
):
r
"""
Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file
=
self
.
_get_expected_file
(
name
)
if
ACCEPT
:
filename
=
{
os
.
path
.
basename
(
expected_file
)}
print
(
"Accepting updated output for {}:
\n\n
{}"
.
format
(
filename
,
output
))
torch
.
save
(
output
,
expected_file
)
MAX_PICKLE_SIZE
=
50
*
1000
# 50 KB
binary_size
=
os
.
path
.
getsize
(
expected_file
)
if
binary_size
>
MAX_PICKLE_SIZE
:
raise
RuntimeError
(
"The output for {}, is larger than 50kb"
.
format
(
filename
))
else
:
expected
=
torch
.
load
(
expected_file
)
rtol
=
atol
=
prec
or
self
.
precision
torch
.
testing
.
assert_close
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
,
check_dtype
=
False
)
def
assertEqual
(
self
,
x
,
y
,
prec
=
None
,
message
=
''
,
allow_inf
=
False
):
"""
This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
...
...
@@ -261,58 +205,6 @@ class TestCase(unittest.TestCase):
else
:
super
(
TestCase
,
self
).
assertEqual
(
x
,
y
,
message
)
def
check_jit_scriptable
(
self
,
nn_module
,
args
,
unwrapper
=
None
,
skip
=
False
):
"""
Check that a nn.Module's results in TorchScript match eager and that it
can be exported
"""
if
not
TEST_WITH_SLOW
or
skip
:
# TorchScript is not enabled, skip these tests
msg
=
"The check_jit_scriptable test for {} was skipped. "
\
"This test checks if the module's results in TorchScript "
\
"match eager and that it can be exported. To run these "
\
"tests make sure you set the environment variable "
\
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
\
"manually skipped."
.
format
(
nn_module
.
__class__
.
__name__
)
warnings
.
warn
(
msg
,
RuntimeWarning
)
return
None
sm
=
torch
.
jit
.
script
(
nn_module
)
with
freeze_rng_state
():
eager_out
=
nn_module
(
*
args
)
with
freeze_rng_state
():
script_out
=
sm
(
*
args
)
if
unwrapper
:
script_out
=
unwrapper
(
script_out
)
self
.
assertEqual
(
eager_out
,
script_out
,
prec
=
1e-4
)
self
.
assertExportImportModule
(
sm
,
args
)
return
sm
def
getExportImportCopy
(
self
,
m
):
"""
Save and load a TorchScript model
"""
buffer
=
io
.
BytesIO
()
torch
.
jit
.
save
(
m
,
buffer
)
buffer
.
seek
(
0
)
imported
=
torch
.
jit
.
load
(
buffer
)
return
imported
def
assertExportImportModule
(
self
,
m
,
args
):
"""
Check that the results of a model are the same after saving and loading
"""
m_import
=
self
.
getExportImportCopy
(
m
)
with
freeze_rng_state
():
results
=
m
(
*
args
)
with
freeze_rng_state
():
results_from_imported
=
m_import
(
*
args
)
self
.
assertEqual
(
results
,
results_from_imported
,
prec
=
3e-4
)
@
contextlib
.
contextmanager
def
freeze_rng_state
():
...
...
test/test_models.py
View file @
7daa90ac
import
os
import
io
import
sys
from
common_utils
import
TestCase
,
map_nested_tensor_object
,
freeze_rng_state
,
set_rng_seed
,
IN_CIRCLE_CI
from
common_utils
import
TestCase
,
map_nested_tensor_object
,
freeze_rng_state
,
set_rng_seed
from
_utils_internal
import
get_relative_path
from
collections
import
OrderedDict
from
itertools
import
product
import
functools
...
...
@@ -13,6 +16,9 @@ import warnings
import
pytest
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
def
get_available_classification_models
():
# TODO add a registration mechanism to torchvision.models
return
[
k
for
k
,
v
in
models
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
...
...
@@ -33,6 +39,103 @@ def get_available_video_models():
return
[
k
for
k
,
v
in
models
.
video
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
def
_get_expected_file
(
name
=
None
):
# Determine expected file based on environment
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
__file__
),
"expect"
)
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file
=
expected_file
=
os
.
path
.
join
(
expected_file_base
,
'ModelTester.test_'
+
name
)
expected_file
+=
"_expect.pkl"
if
not
ACCEPT
and
not
os
.
path
.
exists
(
expected_file
):
raise
RuntimeError
(
f
"No expect file exists for
{
os
.
path
.
basename
(
expected_file
)
}
in
{
expected_file
}
; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return
expected_file
def
_assert_expected
(
output
,
name
,
prec
):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file
=
_get_expected_file
(
name
)
if
ACCEPT
:
filename
=
{
os
.
path
.
basename
(
expected_file
)}
print
(
"Accepting updated output for {}:
\n\n
{}"
.
format
(
filename
,
output
))
torch
.
save
(
output
,
expected_file
)
MAX_PICKLE_SIZE
=
50
*
1000
# 50 KB
binary_size
=
os
.
path
.
getsize
(
expected_file
)
if
binary_size
>
MAX_PICKLE_SIZE
:
raise
RuntimeError
(
"The output for {}, is larger than 50kb"
.
format
(
filename
))
else
:
expected
=
torch
.
load
(
expected_file
)
rtol
=
atol
=
prec
torch
.
testing
.
assert_close
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
,
check_dtype
=
False
)
def
_check_jit_scriptable
(
nn_module
,
args
,
unwrapper
=
None
,
skip
=
False
):
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
def
assert_export_import_module
(
m
,
args
):
"""Check that the results of a model are the same after saving and loading"""
def
get_export_import_copy
(
m
):
"""Save and load a TorchScript model"""
buffer
=
io
.
BytesIO
()
torch
.
jit
.
save
(
m
,
buffer
)
buffer
.
seek
(
0
)
imported
=
torch
.
jit
.
load
(
buffer
)
return
imported
m_import
=
get_export_import_copy
(
m
)
with
freeze_rng_state
():
results
=
m
(
*
args
)
with
freeze_rng_state
():
results_from_imported
=
m_import
(
*
args
)
tol
=
3e-4
try
:
torch
.
testing
.
assert_close
(
results
,
results_from_imported
,
atol
=
tol
,
rtol
=
tol
)
except
pytest
.
UsageError
:
# custom check for the models that return named tuples:
# we compare field by field while ignoring None as assert_close can't handle None
for
a
,
b
in
zip
(
results
,
results_from_imported
):
if
a
is
not
None
:
torch
.
testing
.
assert_close
(
a
,
b
,
atol
=
tol
,
rtol
=
tol
)
TEST_WITH_SLOW
=
os
.
getenv
(
'PYTORCH_TEST_WITH_SLOW'
,
'0'
)
==
'1'
if
not
TEST_WITH_SLOW
or
skip
:
# TorchScript is not enabled, skip these tests
msg
=
"The check_jit_scriptable test for {} was skipped. "
\
"This test checks if the module's results in TorchScript "
\
"match eager and that it can be exported. To run these "
\
"tests make sure you set the environment variable "
\
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
\
"manually skipped."
.
format
(
nn_module
.
__class__
.
__name__
)
warnings
.
warn
(
msg
,
RuntimeWarning
)
return
None
sm
=
torch
.
jit
.
script
(
nn_module
)
with
freeze_rng_state
():
eager_out
=
nn_module
(
*
args
)
with
freeze_rng_state
():
script_out
=
sm
(
*
args
)
if
unwrapper
:
script_out
=
unwrapper
(
script_out
)
torch
.
testing
.
assert_close
(
eager_out
,
script_out
,
atol
=
1e-4
,
rtol
=
1e-4
)
assert_export_import_module
(
sm
,
args
)
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
...
...
@@ -132,16 +235,16 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x
=
torch
.
rand
(
input_shape
).
to
(
device
=
dev
)
out
=
model
(
x
)
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
out
=
model
(
x
)
# See autocast_flaky_numerics comment at top of file.
if
name
not
in
autocast_flaky_numerics
:
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
def
_test_segmentation_model
(
self
,
name
,
dev
):
...
...
@@ -166,12 +269,12 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
prec
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
prec
)
except
AssertionError
:
# Unfortunately some segmentation models are flaky with autocast
# so instead of validating the probability scores, check that the class
# predictions match.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
torch
.
testing
.
assert_close
(
out
.
argmax
(
dim
=
1
),
expected
.
argmax
(
dim
=
1
),
rtol
=
prec
,
atol
=
prec
)
return
False
# Partial validation performed
...
...
@@ -180,7 +283,7 @@ class ModelTester(TestCase):
full_validation
=
check_out
(
out
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
...
...
@@ -248,13 +351,13 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self
.
assert
E
xpected
(
output
,
name
,
prec
=
prec
)
_
assert
_e
xpected
(
output
,
name
,
prec
=
prec
)
except
AssertionError
:
# Unfortunately detection models are flaky due to the unstable sort
# in NMS. If matching across all outputs fails, use the same approach
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
# scores.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
torch
.
testing
.
assert_close
(
output
[
0
][
"scores"
],
expected
[
0
][
"scores"
],
rtol
=
prec
,
atol
=
prec
,
check_device
=
False
,
check_dtype
=
False
)
...
...
@@ -268,7 +371,7 @@ class ModelTester(TestCase):
return
True
# Full validation performed
full_validation
=
check_out
(
out
)
self
.
check_jit_scriptable
(
model
,
([
x
],),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
([
x
],),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
...
...
@@ -318,7 +421,7 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x
=
torch
.
rand
(
input_shape
).
to
(
device
=
dev
)
out
=
model
(
x
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
if
dev
==
torch
.
device
(
"cuda"
):
...
...
@@ -398,7 +501,7 @@ class ModelTester(TestCase):
model
.
AuxLogits
=
None
model
=
model
.
eval
()
x
=
torch
.
rand
(
1
,
3
,
299
,
299
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
def
test_fasterrcnn_double
(
self
):
model
=
models
.
detection
.
fasterrcnn_resnet50_fpn
(
num_classes
=
50
,
pretrained_backbone
=
False
)
...
...
@@ -427,7 +530,7 @@ class ModelTester(TestCase):
model
.
aux2
=
None
model
=
model
.
eval
()
x
=
torch
.
rand
(
1
,
3
,
224
,
224
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
'needs GPU'
)
def
test_fasterrcnn_switch_devices
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment