Unverified Commit 57e87769 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add basic model testing. (#811)

* Add basic model testing.

Also fixes flaky test

* Fix flake8
parent c88d7fb5
import torch
from torchvision import models
import unittest
def get_available_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0]]
class Tester(unittest.TestCase):
def _test_model(self, name, input_shape):
model = models.__dict__[name]()
model.eval()
x = torch.rand(input_shape)
out = model(x)
self.assertEqual(out.shape[-1], 1000)
for model_name in get_available_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_model(model_name, input_shape)
setattr(Tester, "test_" + model_name, do_test)
if __name__ == '__main__':
unittest.main()
......@@ -142,7 +142,7 @@ class Tester(unittest.TestCase):
for i in range(10):
scale_min = round(random.random(), 2)
scale_range = (scale_min, scale_min + round(random.random(), 2))
aspect_min = round(random.random(), 2)
aspect_min = max(round(random.random(), 2), epsilon)
aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
_, _, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment