Unverified Commit 4521f6d1 authored by Vasilis Vryniotis, committed by GitHub

Refactor & enable JIT tests in all models and add warnings if skipped (#3033)

* Enable JIT tests in all models and add a warning when checkModule() tests are skipped.

* Turn on JIT tests in CI.

* Fix broken unit tests.

* Refactor and clean up duplicate code.
parent a51c49e4
@@ -5,5 +5,6 @@ set -e
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env
+export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
 pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
@@ -5,5 +5,6 @@ set -e
 eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
 conda activate ./env
+export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
 pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
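For reference, the flag exported above is what gates the new JIT checks: the test helpers read it once at import time. A minimal sketch of how it is consumed (the getenv default is an assumption; TEST_WITH_SLOW is the name consulted by check_jit_scriptable in the hunks below):

import os

# Assumed convention, mirroring PyTorch's test utilities: the CI scripts
# export PYTORCH_TEST_WITH_SLOW='1', and the test helpers turn it into a
# module-level boolean that check_jit_scriptable() consults.
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'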
@@ -7,7 +7,7 @@ import argparse
 import sys
 import io
 import torch
-import errno
+import warnings
 import __main__
 from numbers import Number
@@ -265,14 +265,21 @@ class TestCase(unittest.TestCase):
         else:
             super(TestCase, self).assertEqual(x, y, message)

-    def checkModule(self, nn_module, args, unwrapper=None, skip=False):
+    def check_jit_scriptable(self, nn_module, args, unwrapper=None, skip=False):
         """
         Check that a nn.Module's results in TorchScript match eager and that it
         can be exported
         """
         if not TEST_WITH_SLOW or skip:
-            # TorchScript is not enabled, skip these tests
-            return
+            msg = "The check_jit_scriptable test for {} was skipped. " \
+                  "This test checks if the module's results in TorchScript " \
+                  "match eager and that it can be exported. To run these " \
+                  "tests make sure you set the environment variable " \
+                  "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
+                  "manually skipped.".format(nn_module.__class__.__name__)
+            warnings.warn(msg, RuntimeWarning)
+            return None

         sm = torch.jit.script(nn_module)
@@ -284,7 +291,7 @@ class TestCase(unittest.TestCase):
         if unwrapper:
             script_out = unwrapper(script_out)

-        self.assertEqual(eager_out, script_out)
+        self.assertEqual(eager_out, script_out, prec=1e-4)

         self.assertExportImportModule(sm, args)
         return sm
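For context, a hypothetical caller of the renamed helper could look like the sketch below; SimpleModel and the test class are illustrative stand-ins, and TestCase is the class patched in the hunks above:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    # Illustrative stand-in for a torchvision model.
    def forward(self, x):
        return torch.relu(x)

class SimpleModelTester(TestCase):  # TestCase as patched above
    def test_simple_model_is_scriptable(self):
        model = SimpleModel().eval()
        x = torch.rand(2, 3)
        # Returns the scripted module when PYTORCH_TEST_WITH_SLOW=1 is set;
        # otherwise emits a RuntimeWarning and returns None.
        sm = self.check_jit_scriptable(model, (x,))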
@@ -38,44 +38,16 @@ def get_available_video_models():
     return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

-# models that are in torch hub, as well as r3d_18. we tried testing all models
-# but the test was too slow. not included are detection models, because
-# they are not yet supported in JIT.
 # If 'unwrapper' is provided it will be called with the script model outputs
 # before they are compared to the eager model outputs. This is useful if the
 # model outputs are different between TorchScript / Eager mode
-script_test_models = {
-    'deeplabv3_resnet50': {},
-    'deeplabv3_resnet101': {},
-    'mobilenet_v2': {},
-    'resnext50_32x4d': {},
-    'fcn_resnet50': {},
-    'fcn_resnet101': {},
-    'googlenet': {
-        'unwrapper': lambda x: x.logits
-    },
-    'densenet121': {},
-    'resnet18': {},
-    'alexnet': {},
-    'shufflenet_v2_x1_0': {},
-    'squeezenet1_0': {},
-    'vgg11': {},
-    'inception_v3': {
-        'unwrapper': lambda x: x.logits
-    },
-    'r3d_18': {},
-    "fasterrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "maskrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "keypointrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "retinanet_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    }
-}
+script_model_unwrapper = {
+    'googlenet': lambda x: x.logits,
+    'inception_v3': lambda x: x.logits,
+    "fasterrcnn_resnet50_fpn": lambda x: x[1],
+    "maskrcnn_resnet50_fpn": lambda x: x[1],
+    "keypointrcnn_resnet50_fpn": lambda x: x[1],
+    "retinanet_resnet50_fpn": lambda x: x[1],
+}
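To see why the remaining entries exist, consider googlenet: in eval mode the eager model returns a plain Tensor, while the scripted module always returns the GoogLeNetOutputs namedtuple. The snippet below is an illustrative sketch of that mismatch, not code from this commit:

import torch
from torchvision import models

model = models.googlenet(num_classes=10, init_weights=False).eval()
scripted = torch.jit.script(model)

x = torch.rand(1, 3, 224, 224)
eager_out = model(x)        # plain Tensor in eval mode
script_out = scripted(x)    # GoogLeNetOutputs(logits=..., aux_logits2=..., aux_logits1=...)
unwrap = lambda out: out.logits  # the 'googlenet' entry above
assert torch.allclose(eager_out, unwrap(script_out), atol=1e-4)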
@@ -97,12 +69,6 @@ autocast_flaky_numerics = (
 class ModelTester(TestCase):
-    def checkModule(self, model, name, args):
-        if name not in script_test_models:
-            return
-
-        unwrapper = script_test_models[name].get('unwrapper', None)
-        return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)
-
     def _test_classification_model(self, name, input_shape, dev):
         set_rng_seed(0)
         # passing num_class equal to a number other than 1000 helps in making the test
@@ -114,7 +80,7 @@ class ModelTester(TestCase):
         out = model(x)

         self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
         self.assertEqual(out.shape[-1], 50)
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
@@ -134,7 +100,7 @@ class ModelTester(TestCase):
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
@@ -209,18 +175,7 @@ class ModelTester(TestCase):
             return True  # Full validation performed

         full_validation = check_out(out)
-
-        scripted_model = torch.jit.script(model)
-        scripted_model.eval()
-        scripted_out = scripted_model(model_input)[1]
-        self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"])
-        self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
-        # labels currently float in script: need to investigate (though same result)
-        self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
-
-        # don't check script because we are compiling it here:
-        # TODO: refactor tests
-        # self.check_script(model, name)
-        self.checkModule(model, name, ([x],))
+        self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
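The block removed above duplicated by hand what check_jit_scriptable now does, including the output unwrapping. A sketch of the mismatch that the 'lambda x: x[1]' unwrapper handles (illustrative, not code from this commit):

import torch
from torchvision import models

model = models.detection.fasterrcnn_resnet50_fpn(
    num_classes=5, pretrained_backbone=False).eval()
scripted = torch.jit.script(model)

x = [torch.rand(3, 300, 300)]
eager_out = model(x)               # list of dicts with 'boxes', 'labels', 'scores'
losses, detections = scripted(x)   # scripted detection models return (losses, detections)
# 'detections' corresponds to eager_out; this is what 'lambda x: x[1]' extracts.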
@@ -270,7 +225,7 @@ class ModelTester(TestCase):
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
         self.assertEqual(out.shape[-1], 50)

         if dev == "cuda":
@@ -345,11 +300,13 @@ class ModelTester(TestCase):
         kwargs['transform_input'] = True
         kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
+        name = "inception_v3"
         model = models.Inception3(**kwargs)
         model.aux_logits = False
         model.AuxLogits = None
-        m = torch.jit.script(model.eval())
-        self.checkModule(m, "inception_v3", torch.rand(1, 3, 299, 299))
+        model = model.eval()
+        x = torch.rand(1, 3, 299, 299)
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

     def test_fasterrcnn_double(self):
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
@@ -371,12 +328,14 @@ class ModelTester(TestCase):
         kwargs['transform_input'] = True
         kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
+        name = "googlenet"
         model = models.GoogLeNet(**kwargs)
         model.aux_logits = False
         model.aux1 = None
         model.aux2 = None
-        m = torch.jit.script(model.eval())
-        self.checkModule(m, "googlenet", torch.rand(1, 3, 224, 224))
+        model = model.eval()
+        x = torch.rand(1, 3, 224, 224)
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

     @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
     def test_fasterrcnn_switch_devices(self):
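Both hunks above exercise the same edge case through the shared helper: a model built with aux_logits=True whose aux heads are stripped afterwards must still script and match eager. A minimal standalone sketch of that scenario (assumed behaviour, mirroring the googlenet test):

import torch
from torchvision import models

model = models.GoogLeNet(num_classes=10, aux_logits=True, init_weights=False)
model.aux_logits = False
model.aux1 = None   # strip the aux heads after construction
model.aux2 = None
model = model.eval()

scripted = torch.jit.script(model)   # must still compile cleanly
x = torch.rand(1, 3, 224, 224)
out = scripted(x).logits             # same unwrapping as script_model_unwrapper['googlenet']
assert torch.allclose(out, model(x), atol=1e-4)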