Unverified Commit 7326623a authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Port test_quantized_models.py to pytest (#4034)

parent 2d6931ab
......@@ -8,9 +8,11 @@ import functools
import operator
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import pytest
import warnings
import traceback
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
......@@ -36,6 +38,11 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def get_available_quantizable_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def _get_expected_file(name=None):
# Determine expected file based on environment
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
......@@ -617,5 +624,49 @@ def test_video_model(model_name, dev):
assert out.shape[-1] == 50
@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
'qnnpack' in torch.backends.quantized.supported_engines),
reason="This Pytorch Build has not been built with fbgemm and qnnpack")
@pytest.mark.parametrize('model_name', get_available_quantizable_models())
def test_quantized_classification_model(model_name):
defaults = {
'input_shape': (1, 3, 224, 224),
'pretrained': False,
'quantize': True,
}
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape')
# First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[model_name](**kwargs)
x = torch.rand(input_shape)
model(x)
kwargs['quantize'] = False
for eval_mode in [True, False]:
model = torchvision.models.quantization.__dict__[model_name](**kwargs)
if eval_mode:
model.eval()
model.qconfig = torch.quantization.default_qconfig
else:
model.train()
model.qconfig = torch.quantization.default_qat_qconfig
model.fuse_model()
if eval_mode:
torch.quantization.prepare(model, inplace=True)
else:
torch.quantization.prepare_qat(model, inplace=True)
model.eval()
torch.quantization.convert(model, inplace=True)
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
if __name__ == '__main__':
pytest.main([__file__])
import torchvision
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_available_quantizable_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# list of models that are not scriptable
scriptable_quantizable_models_blacklist = []
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and
'qnnpack' in torch.backends.quantized.supported_engines,
"This Pytorch Build has not been built with fbgemm and qnnpack")
class ModelTester(TestCase):
def check_quantized_model(self, model, input_shape):
x = torch.rand(input_shape)
model(x)
return
def check_script(self, model, name):
if name in scriptable_quantizable_models_blacklist:
return
scriptable = True
msg = ""
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertTrue(scriptable, msg)
def _test_classification_model(self, name, input_shape):
# First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=True)
self.check_quantized_model(model, input_shape)
for eval_mode in [True, False]:
model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False)
if eval_mode:
model.eval()
model.qconfig = torch.quantization.default_qconfig
else:
model.train()
model.qconfig = torch.quantization.default_qat_qconfig
model.fuse_model()
if eval_mode:
torch.quantization.prepare(model, inplace=True)
else:
torch.quantization.prepare_qat(model, inplace=True)
model.eval()
torch.quantization.convert(model, inplace=True)
self.check_script(model, name)
for model_name in get_available_quantizable_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape)
# inception_v3 was causing timeouts on circleci
# See https://github.com/pytorch/vision/issues/1857
if model_name not in ['inception_v3']:
setattr(ModelTester, "test_" + model_name, do_test)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment