Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7daa90ac
Unverified
Commit
7daa90ac
authored
Jun 03, 2021
by
Nicolas Hug
Committed by
GitHub
Jun 03, 2021
Browse files
Take assertExpected and check_jit_scriptable out of the TestCase class (#3947)
parent
4c0fdc61
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
121 deletions
+116
-121
test/common_utils.py
test/common_utils.py
+0
-108
test/test_models.py
test/test_models.py
+116
-13
No files found.
test/common_utils.py
View file @
7daa90ac
...
...
@@ -5,9 +5,7 @@ import contextlib
import
unittest
import
argparse
import
sys
import
io
import
torch
import
warnings
import
__main__
import
random
import
inspect
...
...
@@ -15,7 +13,6 @@ import inspect
from
numbers
import
Number
from
torch._six
import
string_classes
from
collections
import
OrderedDict
from
_utils_internal
import
get_relative_path
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -49,10 +46,6 @@ def set_rng_seed(seed):
np
.
random
.
seed
(
seed
)
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
TEST_WITH_SLOW
=
os
.
getenv
(
'PYTORCH_TEST_WITH_SLOW'
,
'0'
)
==
'1'
class
MapNestedTensorObjectImpl
(
object
):
def
__init__
(
self
,
tensor_map_fn
):
self
.
tensor_map_fn
=
tensor_map_fn
...
...
@@ -95,55 +88,6 @@ def is_iterable(obj):
class
TestCase
(
unittest
.
TestCase
):
precision
=
1e-5
def
_get_expected_file
(
self
,
name
=
None
):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id
=
self
.
__class__
.
__module__
# Determine expected file based on environment
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
sys
.
modules
[
module_id
].
__file__
),
"expect"
)
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file
=
expected_file
=
os
.
path
.
join
(
expected_file_base
,
'ModelTester.test_'
+
name
)
expected_file
+=
"_expect.pkl"
if
not
ACCEPT
and
not
os
.
path
.
exists
(
expected_file
):
raise
RuntimeError
(
f
"No expect file exists for
{
os
.
path
.
basename
(
expected_file
)
}
in
{
expected_file
}
; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return
expected_file
def
assertExpected
(
self
,
output
,
name
,
prec
=
None
):
r
"""
Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file
=
self
.
_get_expected_file
(
name
)
if
ACCEPT
:
filename
=
{
os
.
path
.
basename
(
expected_file
)}
print
(
"Accepting updated output for {}:
\n\n
{}"
.
format
(
filename
,
output
))
torch
.
save
(
output
,
expected_file
)
MAX_PICKLE_SIZE
=
50
*
1000
# 50 KB
binary_size
=
os
.
path
.
getsize
(
expected_file
)
if
binary_size
>
MAX_PICKLE_SIZE
:
raise
RuntimeError
(
"The output for {}, is larger than 50kb"
.
format
(
filename
))
else
:
expected
=
torch
.
load
(
expected_file
)
rtol
=
atol
=
prec
or
self
.
precision
torch
.
testing
.
assert_close
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
,
check_dtype
=
False
)
def
assertEqual
(
self
,
x
,
y
,
prec
=
None
,
message
=
''
,
allow_inf
=
False
):
"""
This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
...
...
@@ -261,58 +205,6 @@ class TestCase(unittest.TestCase):
else
:
super
(
TestCase
,
self
).
assertEqual
(
x
,
y
,
message
)
def
check_jit_scriptable
(
self
,
nn_module
,
args
,
unwrapper
=
None
,
skip
=
False
):
"""
Check that a nn.Module's results in TorchScript match eager and that it
can be exported
"""
if
not
TEST_WITH_SLOW
or
skip
:
# TorchScript is not enabled, skip these tests
msg
=
"The check_jit_scriptable test for {} was skipped. "
\
"This test checks if the module's results in TorchScript "
\
"match eager and that it can be exported. To run these "
\
"tests make sure you set the environment variable "
\
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
\
"manually skipped."
.
format
(
nn_module
.
__class__
.
__name__
)
warnings
.
warn
(
msg
,
RuntimeWarning
)
return
None
sm
=
torch
.
jit
.
script
(
nn_module
)
with
freeze_rng_state
():
eager_out
=
nn_module
(
*
args
)
with
freeze_rng_state
():
script_out
=
sm
(
*
args
)
if
unwrapper
:
script_out
=
unwrapper
(
script_out
)
self
.
assertEqual
(
eager_out
,
script_out
,
prec
=
1e-4
)
self
.
assertExportImportModule
(
sm
,
args
)
return
sm
def
getExportImportCopy
(
self
,
m
):
"""
Save and load a TorchScript model
"""
buffer
=
io
.
BytesIO
()
torch
.
jit
.
save
(
m
,
buffer
)
buffer
.
seek
(
0
)
imported
=
torch
.
jit
.
load
(
buffer
)
return
imported
def
assertExportImportModule
(
self
,
m
,
args
):
"""
Check that the results of a model are the same after saving and loading
"""
m_import
=
self
.
getExportImportCopy
(
m
)
with
freeze_rng_state
():
results
=
m
(
*
args
)
with
freeze_rng_state
():
results_from_imported
=
m_import
(
*
args
)
self
.
assertEqual
(
results
,
results_from_imported
,
prec
=
3e-4
)
@
contextlib
.
contextmanager
def
freeze_rng_state
():
...
...
test/test_models.py
View file @
7daa90ac
import
os
import
io
import
sys
from
common_utils
import
TestCase
,
map_nested_tensor_object
,
freeze_rng_state
,
set_rng_seed
,
IN_CIRCLE_CI
from
common_utils
import
TestCase
,
map_nested_tensor_object
,
freeze_rng_state
,
set_rng_seed
from
_utils_internal
import
get_relative_path
from
collections
import
OrderedDict
from
itertools
import
product
import
functools
...
...
@@ -13,6 +16,9 @@ import warnings
import
pytest
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
def
get_available_classification_models
():
# TODO add a registration mechanism to torchvision.models
return
[
k
for
k
,
v
in
models
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
...
...
@@ -33,6 +39,103 @@ def get_available_video_models():
return
[
k
for
k
,
v
in
models
.
video
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
def
_get_expected_file
(
name
=
None
):
# Determine expected file based on environment
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
__file__
),
"expect"
)
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file
=
expected_file
=
os
.
path
.
join
(
expected_file_base
,
'ModelTester.test_'
+
name
)
expected_file
+=
"_expect.pkl"
if
not
ACCEPT
and
not
os
.
path
.
exists
(
expected_file
):
raise
RuntimeError
(
f
"No expect file exists for
{
os
.
path
.
basename
(
expected_file
)
}
in
{
expected_file
}
; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return
expected_file
def
_assert_expected
(
output
,
name
,
prec
):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file
=
_get_expected_file
(
name
)
if
ACCEPT
:
filename
=
{
os
.
path
.
basename
(
expected_file
)}
print
(
"Accepting updated output for {}:
\n\n
{}"
.
format
(
filename
,
output
))
torch
.
save
(
output
,
expected_file
)
MAX_PICKLE_SIZE
=
50
*
1000
# 50 KB
binary_size
=
os
.
path
.
getsize
(
expected_file
)
if
binary_size
>
MAX_PICKLE_SIZE
:
raise
RuntimeError
(
"The output for {}, is larger than 50kb"
.
format
(
filename
))
else
:
expected
=
torch
.
load
(
expected_file
)
rtol
=
atol
=
prec
torch
.
testing
.
assert_close
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
,
check_dtype
=
False
)
def
_check_jit_scriptable
(
nn_module
,
args
,
unwrapper
=
None
,
skip
=
False
):
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
def
assert_export_import_module
(
m
,
args
):
"""Check that the results of a model are the same after saving and loading"""
def
get_export_import_copy
(
m
):
"""Save and load a TorchScript model"""
buffer
=
io
.
BytesIO
()
torch
.
jit
.
save
(
m
,
buffer
)
buffer
.
seek
(
0
)
imported
=
torch
.
jit
.
load
(
buffer
)
return
imported
m_import
=
get_export_import_copy
(
m
)
with
freeze_rng_state
():
results
=
m
(
*
args
)
with
freeze_rng_state
():
results_from_imported
=
m_import
(
*
args
)
tol
=
3e-4
try
:
torch
.
testing
.
assert_close
(
results
,
results_from_imported
,
atol
=
tol
,
rtol
=
tol
)
except
pytest
.
UsageError
:
# custom check for the models that return named tuples:
# we compare field by field while ignoring None as assert_close can't handle None
for
a
,
b
in
zip
(
results
,
results_from_imported
):
if
a
is
not
None
:
torch
.
testing
.
assert_close
(
a
,
b
,
atol
=
tol
,
rtol
=
tol
)
TEST_WITH_SLOW
=
os
.
getenv
(
'PYTORCH_TEST_WITH_SLOW'
,
'0'
)
==
'1'
if
not
TEST_WITH_SLOW
or
skip
:
# TorchScript is not enabled, skip these tests
msg
=
"The check_jit_scriptable test for {} was skipped. "
\
"This test checks if the module's results in TorchScript "
\
"match eager and that it can be exported. To run these "
\
"tests make sure you set the environment variable "
\
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
\
"manually skipped."
.
format
(
nn_module
.
__class__
.
__name__
)
warnings
.
warn
(
msg
,
RuntimeWarning
)
return
None
sm
=
torch
.
jit
.
script
(
nn_module
)
with
freeze_rng_state
():
eager_out
=
nn_module
(
*
args
)
with
freeze_rng_state
():
script_out
=
sm
(
*
args
)
if
unwrapper
:
script_out
=
unwrapper
(
script_out
)
torch
.
testing
.
assert_close
(
eager_out
,
script_out
,
atol
=
1e-4
,
rtol
=
1e-4
)
assert_export_import_module
(
sm
,
args
)
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
...
...
@@ -132,16 +235,16 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x
=
torch
.
rand
(
input_shape
).
to
(
device
=
dev
)
out
=
model
(
x
)
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
out
=
model
(
x
)
# See autocast_flaky_numerics comment at top of file.
if
name
not
in
autocast_flaky_numerics
:
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
0.1
)
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
def
_test_segmentation_model
(
self
,
name
,
dev
):
...
...
@@ -166,12 +269,12 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self
.
assert
E
xpected
(
out
.
cpu
(),
name
,
prec
=
prec
)
_
assert
_e
xpected
(
out
.
cpu
(),
name
,
prec
=
prec
)
except
AssertionError
:
# Unfortunately some segmentation models are flaky with autocast
# so instead of validating the probability scores, check that the class
# predictions match.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
torch
.
testing
.
assert_close
(
out
.
argmax
(
dim
=
1
),
expected
.
argmax
(
dim
=
1
),
rtol
=
prec
,
atol
=
prec
)
return
False
# Partial validation performed
...
...
@@ -180,7 +283,7 @@ class ModelTester(TestCase):
full_validation
=
check_out
(
out
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
...
...
@@ -248,13 +351,13 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self
.
assert
E
xpected
(
output
,
name
,
prec
=
prec
)
_
assert
_e
xpected
(
output
,
name
,
prec
=
prec
)
except
AssertionError
:
# Unfortunately detection models are flaky due to the unstable sort
# in NMS. If matching across all outputs fails, use the same approach
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
# scores.
expected_file
=
self
.
_get_expected_file
(
name
)
expected_file
=
_get_expected_file
(
name
)
expected
=
torch
.
load
(
expected_file
)
torch
.
testing
.
assert_close
(
output
[
0
][
"scores"
],
expected
[
0
][
"scores"
],
rtol
=
prec
,
atol
=
prec
,
check_device
=
False
,
check_dtype
=
False
)
...
...
@@ -268,7 +371,7 @@ class ModelTester(TestCase):
return
True
# Full validation performed
full_validation
=
check_out
(
out
)
self
.
check_jit_scriptable
(
model
,
([
x
],),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
([
x
],),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
if
dev
==
torch
.
device
(
"cuda"
):
with
torch
.
cuda
.
amp
.
autocast
():
...
...
@@ -318,7 +421,7 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x
=
torch
.
rand
(
input_shape
).
to
(
device
=
dev
)
out
=
model
(
x
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
self
.
assertEqual
(
out
.
shape
[
-
1
],
50
)
if
dev
==
torch
.
device
(
"cuda"
):
...
...
@@ -398,7 +501,7 @@ class ModelTester(TestCase):
model
.
AuxLogits
=
None
model
=
model
.
eval
()
x
=
torch
.
rand
(
1
,
3
,
299
,
299
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
def
test_fasterrcnn_double
(
self
):
model
=
models
.
detection
.
fasterrcnn_resnet50_fpn
(
num_classes
=
50
,
pretrained_backbone
=
False
)
...
...
@@ -427,7 +530,7 @@ class ModelTester(TestCase):
model
.
aux2
=
None
model
=
model
.
eval
()
x
=
torch
.
rand
(
1
,
3
,
224
,
224
)
self
.
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
_
check_jit_scriptable
(
model
,
(
x
,),
unwrapper
=
script_model_unwrapper
.
get
(
name
,
None
))
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
'needs GPU'
)
def
test_fasterrcnn_switch_devices
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment