Unverified Commit 7326623a authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Port test_quantized_models.py to pytest (#4034)

parent 2d6931ab
...@@ -8,9 +8,11 @@ import functools ...@@ -8,9 +8,11 @@ import functools
import operator import operator
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision
from torchvision import models from torchvision import models
import pytest import pytest
import warnings import warnings
import traceback
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1' ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
...@@ -36,6 +38,11 @@ def get_available_video_models(): ...@@ -36,6 +38,11 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def get_available_quantizable_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def _get_expected_file(name=None): def _get_expected_file(name=None):
# Determine expected file based on environment # Determine expected file based on environment
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect") expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
...@@ -617,5 +624,49 @@ def test_video_model(model_name, dev): ...@@ -617,5 +624,49 @@ def test_video_model(model_name, dev):
assert out.shape[-1] == 50 assert out.shape[-1] == 50
@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
'qnnpack' in torch.backends.quantized.supported_engines),
reason="This Pytorch Build has not been built with fbgemm and qnnpack")
@pytest.mark.parametrize('model_name', get_available_quantizable_models())
def test_quantized_classification_model(model_name):
defaults = {
'input_shape': (1, 3, 224, 224),
'pretrained': False,
'quantize': True,
}
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape')
# First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[model_name](**kwargs)
x = torch.rand(input_shape)
model(x)
kwargs['quantize'] = False
for eval_mode in [True, False]:
model = torchvision.models.quantization.__dict__[model_name](**kwargs)
if eval_mode:
model.eval()
model.qconfig = torch.quantization.default_qconfig
else:
model.train()
model.qconfig = torch.quantization.default_qat_qconfig
model.fuse_model()
if eval_mode:
torch.quantization.prepare(model, inplace=True)
else:
torch.quantization.prepare_qat(model, inplace=True)
model.eval()
torch.quantization.convert(model, inplace=True)
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
if __name__ == '__main__': if __name__ == '__main__':
pytest.main([__file__]) pytest.main([__file__])
import torchvision
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_available_quantizable_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# list of models that are not scriptable
scriptable_quantizable_models_blacklist = []
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and
'qnnpack' in torch.backends.quantized.supported_engines,
"This Pytorch Build has not been built with fbgemm and qnnpack")
class ModelTester(TestCase):
def check_quantized_model(self, model, input_shape):
x = torch.rand(input_shape)
model(x)
return
def check_script(self, model, name):
if name in scriptable_quantizable_models_blacklist:
return
scriptable = True
msg = ""
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertTrue(scriptable, msg)
def _test_classification_model(self, name, input_shape):
# First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=True)
self.check_quantized_model(model, input_shape)
for eval_mode in [True, False]:
model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False)
if eval_mode:
model.eval()
model.qconfig = torch.quantization.default_qconfig
else:
model.train()
model.qconfig = torch.quantization.default_qat_qconfig
model.fuse_model()
if eval_mode:
torch.quantization.prepare(model, inplace=True)
else:
torch.quantization.prepare_qat(model, inplace=True)
model.eval()
torch.quantization.convert(model, inplace=True)
self.check_script(model, name)
for model_name in get_available_quantizable_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape)
# inception_v3 was causing timeouts on circleci
# See https://github.com/pytorch/vision/issues/1857
if model_name not in ['inception_v3']:
setattr(ModelTester, "test_" + model_name, do_test)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment