OpenDAS / vision · Commits · 7326623a

Commit 7326623a (unverified)
Authored Jun 10, 2021 by Nicolas Hug; committed by GitHub on Jun 10, 2021

    Port test_quantized_models.py to pytest (#4034)

Parent: 2d6931ab

Showing 2 changed files with 51 additions and 93 deletions (+51 -93)
  test/test_models.py            +51  -0
  test/test_quantized_models.py   +0  -93
test/test_models.py (view file @ 7326623a)
@@ -8,9 +8,11 @@ import functools
import operator

import torch
import torch.nn as nn
import torchvision
from torchvision import models
import pytest
import warnings
import traceback

ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'

@@ -36,6 +38,11 @@ def get_available_video_models():
    return [k for k, v in models.video.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


def get_available_quantizable_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.quantization.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


def _get_expected_file(name=None):
    # Determine expected file based on environment
    expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")

@@ -617,5 +624,49 @@ def test_video_model(model_name, dev):
    assert out.shape[-1] == 50


@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
                         'qnnpack' in torch.backends.quantized.supported_engines),
                    reason="This Pytorch Build has not been built with fbgemm and qnnpack")
@pytest.mark.parametrize('model_name', get_available_quantizable_models())
def test_quantized_classification_model(model_name):
    defaults = {
        'input_shape': (1, 3, 224, 224),
        'pretrained': False,
        'quantize': True,
    }
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop('input_shape')

    # First check if quantize=True provides models that can run with input data
    model = torchvision.models.quantization.__dict__[model_name](**kwargs)
    x = torch.rand(input_shape)
    model(x)

    kwargs['quantize'] = False
    for eval_mode in [True, False]:
        model = torchvision.models.quantization.__dict__[model_name](**kwargs)

        if eval_mode:
            model.eval()
            model.qconfig = torch.quantization.default_qconfig
        else:
            model.train()
            model.qconfig = torch.quantization.default_qat_qconfig

        model.fuse_model()

        if eval_mode:
            torch.quantization.prepare(model, inplace=True)
        else:
            torch.quantization.prepare_qat(model, inplace=True)
            model.eval()

        torch.quantization.convert(model, inplace=True)

    try:
        torch.jit.script(model)
    except Exception as e:
        tb = traceback.format_exc()
        raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e


if __name__ == '__main__':
    pytest.main([__file__])
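With this change, the quantized-model checks run as ordinary pytest tests, so a single model can be selected by name with pytest's -k keyword filter. A minimal usage sketch, not part of the commit, assuming a local checkout with pytest installed ('resnet18' is only an illustrative model name):

    # Run the new quantized test for one model from Python via pytest's own API.
    import pytest

    if __name__ == '__main__':
        # '-k' filters collected tests by name; '-v' prints one line per parametrized case.
        pytest.main(['test/test_models.py', '-v',
                     '-k', 'test_quantized_classification_model and resnet18'])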
test/test_quantized_models.py — deleted, 100644 → 0 (view file @ 2d6931ab)
import torchvision
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def get_available_quantizable_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.quantization.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# list of models that are not scriptable
scriptable_quantizable_models_blacklist = []


@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and
                     'qnnpack' in torch.backends.quantized.supported_engines,
                     "This Pytorch Build has not been built with fbgemm and qnnpack")
class ModelTester(TestCase):
    def check_quantized_model(self, model, input_shape):
        x = torch.rand(input_shape)
        model(x)
        return

    def check_script(self, model, name):
        if name in scriptable_quantizable_models_blacklist:
            return
        scriptable = True
        msg = ""
        try:
            torch.jit.script(model)
        except Exception as e:
            tb = traceback.format_exc()
            scriptable = False
            msg = str(e) + str(tb)
        self.assertTrue(scriptable, msg)

    def _test_classification_model(self, name, input_shape):
        # First check if quantize=True provides models that can run with input data
        model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=True)
        self.check_quantized_model(model, input_shape)

        for eval_mode in [True, False]:
            model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False)

            if eval_mode:
                model.eval()
                model.qconfig = torch.quantization.default_qconfig
            else:
                model.train()
                model.qconfig = torch.quantization.default_qat_qconfig

            model.fuse_model()

            if eval_mode:
                torch.quantization.prepare(model, inplace=True)
            else:
                torch.quantization.prepare_qat(model, inplace=True)
                model.eval()

            torch.quantization.convert(model, inplace=True)

        self.check_script(model, name)


for model_name in get_available_quantizable_models():
    # for-loop bodies don't define scopes, so we have to save the variables
    # we want to close over in some way
    def do_test(self, model_name=model_name):
        input_shape = (1, 3, 224, 224)
        if model_name in ['inception_v3']:
            input_shape = (1, 3, 299, 299)
        self._test_classification_model(model_name, input_shape)

    # inception_v3 was causing timeouts on circleci
    # See https://github.com/pytorch/vision/issues/1857
    if model_name not in ['inception_v3']:
        setattr(ModelTester, "test_" + model_name, do_test)


if __name__ == '__main__':
    unittest.main()
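The deleted file attached one generated method per model to ModelTester with setattr and simply never registered a test for inception_v3; with pytest, an equivalent per-model skip could be expressed directly in the parameter list. A hypothetical sketch, not taken from the commit (the model names, reason string, and test body are illustrative):

    import pytest

    QUANTIZABLE_MODELS = [
        'resnet18',
        'mobilenet_v2',
        # Per-model skips attach to the parametrization instead of a setattr filter.
        pytest.param('inception_v3',
                     marks=pytest.mark.skip(reason='timeouts on CI, see pytorch/vision#1857')),
    ]

    @pytest.mark.parametrize('model_name', QUANTIZABLE_MODELS)
    def test_model_name_is_wellformed(model_name):
        # Placeholder body so the sketch is runnable on its own.
        assert model_name and not model_name.startswith('_')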