Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7326623a
Unverified
Commit
7326623a
authored
Jun 10, 2021
by
Nicolas Hug
Committed by
GitHub
Jun 10, 2021
Browse files
Port test_quantized_models.py to pytest (#4034)
parent
2d6931ab
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
93 deletions
+51
-93
test/test_models.py
test/test_models.py
+51
-0
test/test_quantized_models.py
test/test_quantized_models.py
+0
-93
No files found.
test/test_models.py
View file @
7326623a
...
@@ -8,9 +8,11 @@ import functools
...
@@ -8,9 +8,11 @@ import functools
import
operator
import
operator
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torchvision
from
torchvision
import
models
from
torchvision
import
models
import
pytest
import
pytest
import
warnings
import
warnings
import
traceback
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
ACCEPT
=
os
.
getenv
(
'EXPECTTEST_ACCEPT'
,
'0'
)
==
'1'
...
@@ -36,6 +38,11 @@ def get_available_video_models():
...
@@ -36,6 +38,11 @@ def get_available_video_models():
return
[
k
for
k
,
v
in
models
.
video
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
return
[
k
for
k
,
v
in
models
.
video
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
def
get_available_quantizable_models
():
# TODO add a registration mechanism to torchvision.models
return
[
k
for
k
,
v
in
models
.
quantization
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
def
_get_expected_file
(
name
=
None
):
def
_get_expected_file
(
name
=
None
):
# Determine expected file based on environment
# Determine expected file based on environment
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
__file__
),
"expect"
)
expected_file_base
=
get_relative_path
(
os
.
path
.
realpath
(
__file__
),
"expect"
)
...
@@ -617,5 +624,49 @@ def test_video_model(model_name, dev):
...
@@ -617,5 +624,49 @@ def test_video_model(model_name, dev):
assert
out
.
shape
[
-
1
]
==
50
assert
out
.
shape
[
-
1
]
==
50
@
pytest
.
mark
.
skipif
(
not
(
'fbgemm'
in
torch
.
backends
.
quantized
.
supported_engines
and
'qnnpack'
in
torch
.
backends
.
quantized
.
supported_engines
),
reason
=
"This Pytorch Build has not been built with fbgemm and qnnpack"
)
@
pytest
.
mark
.
parametrize
(
'model_name'
,
get_available_quantizable_models
())
def
test_quantized_classification_model
(
model_name
):
defaults
=
{
'input_shape'
:
(
1
,
3
,
224
,
224
),
'pretrained'
:
False
,
'quantize'
:
True
,
}
kwargs
=
{
**
defaults
,
**
_model_params
.
get
(
model_name
,
{})}
input_shape
=
kwargs
.
pop
(
'input_shape'
)
# First check if quantize=True provides models that can run with input data
model
=
torchvision
.
models
.
quantization
.
__dict__
[
model_name
](
**
kwargs
)
x
=
torch
.
rand
(
input_shape
)
model
(
x
)
kwargs
[
'quantize'
]
=
False
for
eval_mode
in
[
True
,
False
]:
model
=
torchvision
.
models
.
quantization
.
__dict__
[
model_name
](
**
kwargs
)
if
eval_mode
:
model
.
eval
()
model
.
qconfig
=
torch
.
quantization
.
default_qconfig
else
:
model
.
train
()
model
.
qconfig
=
torch
.
quantization
.
default_qat_qconfig
model
.
fuse_model
()
if
eval_mode
:
torch
.
quantization
.
prepare
(
model
,
inplace
=
True
)
else
:
torch
.
quantization
.
prepare_qat
(
model
,
inplace
=
True
)
model
.
eval
()
torch
.
quantization
.
convert
(
model
,
inplace
=
True
)
try
:
torch
.
jit
.
script
(
model
)
except
Exception
as
e
:
tb
=
traceback
.
format_exc
()
raise
AssertionError
(
f
"model cannot be scripted. Traceback =
{
str
(
tb
)
}
"
)
from
e
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
test/test_quantized_models.py
deleted
100644 → 0
View file @
2d6931ab
import
torchvision
from
common_utils
import
TestCase
,
map_nested_tensor_object
from
collections
import
OrderedDict
from
itertools
import
product
import
torch
import
numpy
as
np
from
torchvision
import
models
import
unittest
import
traceback
import
random
def
set_rng_seed
(
seed
):
torch
.
manual_seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
def
get_available_quantizable_models
():
# TODO add a registration mechanism to torchvision.models
return
[
k
for
k
,
v
in
models
.
quantization
.
__dict__
.
items
()
if
callable
(
v
)
and
k
[
0
].
lower
()
==
k
[
0
]
and
k
[
0
]
!=
"_"
]
# list of models that are not scriptable
scriptable_quantizable_models_blacklist
=
[]
@
unittest
.
skipUnless
(
'fbgemm'
in
torch
.
backends
.
quantized
.
supported_engines
and
'qnnpack'
in
torch
.
backends
.
quantized
.
supported_engines
,
"This Pytorch Build has not been built with fbgemm and qnnpack"
)
class
ModelTester
(
TestCase
):
def
check_quantized_model
(
self
,
model
,
input_shape
):
x
=
torch
.
rand
(
input_shape
)
model
(
x
)
return
def
check_script
(
self
,
model
,
name
):
if
name
in
scriptable_quantizable_models_blacklist
:
return
scriptable
=
True
msg
=
""
try
:
torch
.
jit
.
script
(
model
)
except
Exception
as
e
:
tb
=
traceback
.
format_exc
()
scriptable
=
False
msg
=
str
(
e
)
+
str
(
tb
)
self
.
assertTrue
(
scriptable
,
msg
)
def
_test_classification_model
(
self
,
name
,
input_shape
):
# First check if quantize=True provides models that can run with input data
model
=
torchvision
.
models
.
quantization
.
__dict__
[
name
](
pretrained
=
False
,
quantize
=
True
)
self
.
check_quantized_model
(
model
,
input_shape
)
for
eval_mode
in
[
True
,
False
]:
model
=
torchvision
.
models
.
quantization
.
__dict__
[
name
](
pretrained
=
False
,
quantize
=
False
)
if
eval_mode
:
model
.
eval
()
model
.
qconfig
=
torch
.
quantization
.
default_qconfig
else
:
model
.
train
()
model
.
qconfig
=
torch
.
quantization
.
default_qat_qconfig
model
.
fuse_model
()
if
eval_mode
:
torch
.
quantization
.
prepare
(
model
,
inplace
=
True
)
else
:
torch
.
quantization
.
prepare_qat
(
model
,
inplace
=
True
)
model
.
eval
()
torch
.
quantization
.
convert
(
model
,
inplace
=
True
)
self
.
check_script
(
model
,
name
)
for
model_name
in
get_available_quantizable_models
():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def
do_test
(
self
,
model_name
=
model_name
):
input_shape
=
(
1
,
3
,
224
,
224
)
if
model_name
in
[
'inception_v3'
]:
input_shape
=
(
1
,
3
,
299
,
299
)
self
.
_test_classification_model
(
model_name
,
input_shape
)
# inception_v3 was causing timeouts on circleci
# See https://github.com/pytorch/vision/issues/1857
if
model_name
not
in
[
'inception_v3'
]:
setattr
(
ModelTester
,
"test_"
+
model_name
,
do_test
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment