Commit 8878068e authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

Test that torchhub models are scriptable (#1242)

* test that torchhub models are scriptable

* fix lint
parent 4f8b8ff1
...@@ -25,11 +25,39 @@ def get_available_video_models(): ...@@ -25,11 +25,39 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": False,
"mobilenet_v2": True,
"resnext50_32x4d": False,
"fcn_resnet101": False,
"googlenet": False,
"densenet121": False,
"resnet18": False,
"alexnet": True,
"shufflenet_v2_x1_0": False,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
}
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def check_script(self, model, name):
if name not in torchub_models:
return
scriptable = True
try:
torch.jit.script(model)
except Exception:
scriptable = False
self.assertEqual(torchub_models[name], scriptable)
def _test_classification_model(self, name, input_shape): def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test # passing num_class equal to a number other than 1000 helps in making the test
# more enforcing in nature # more enforcing in nature
model = models.__dict__[name](num_classes=50) model = models.__dict__[name](num_classes=50)
self.check_script(model, name)
model.eval() model.eval()
x = torch.rand(input_shape) x = torch.rand(input_shape)
out = model(x) out = model(x)
...@@ -39,6 +67,7 @@ class Tester(unittest.TestCase): ...@@ -39,6 +67,7 @@ class Tester(unittest.TestCase):
# passing num_class equal to a number other than 1000 helps in making the test # passing num_class equal to a number other than 1000 helps in making the test
# more enforcing in nature # more enforcing in nature
model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False) model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
self.check_script(model, name)
model.eval() model.eval()
input_shape = (1, 3, 300, 300) input_shape = (1, 3, 300, 300)
x = torch.rand(input_shape) x = torch.rand(input_shape)
...@@ -47,6 +76,7 @@ class Tester(unittest.TestCase): ...@@ -47,6 +76,7 @@ class Tester(unittest.TestCase):
def _test_detection_model(self, name): def _test_detection_model(self, name):
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
self.check_script(model, name)
model.eval() model.eval()
input_shape = (3, 300, 300) input_shape = (3, 300, 300)
x = torch.rand(input_shape) x = torch.rand(input_shape)
...@@ -64,6 +94,7 @@ class Tester(unittest.TestCase): ...@@ -64,6 +94,7 @@ class Tester(unittest.TestCase):
input_shape = (1, 3, 4, 112, 112) input_shape = (1, 3, 4, 112, 112)
# test both basicblock and Bottleneck # test both basicblock and Bottleneck
model = models.video.__dict__[name](num_classes=50) model = models.video.__dict__[name](num_classes=50)
self.check_script(model, name)
x = torch.rand(input_shape) x = torch.rand(input_shape)
out = model(x) out = model(x)
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment