Commit 83d3770a authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Add test for num_class in test_model.py (#815)

* Add test for loading pretrained models

The update modifies the test to check whether the model can successfully load the pretrained weights. Will raise an error if the model parameters are incorrectly defined or named.

* Add test on 'num_class'

Passing num_class equal to a number other than 1000 helps in making the test more enforcing in nature.
parent 6334466e
...@@ -10,11 +10,12 @@ def get_available_models(): ...@@ -10,11 +10,12 @@ def get_available_models():
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def _test_model(self, name, input_shape): def _test_model(self, name, input_shape):
model = models.__dict__[name]() # passing num_class equal to a number other than 1000 helps in making the test more enforcing in nature
model = models.__dict__[name](num_classes=50)
model.eval() model.eval()
x = torch.rand(input_shape) x = torch.rand(input_shape)
out = model(x) out = model(x)
self.assertEqual(out.shape[-1], 1000) self.assertEqual(out.shape[-1], 50)
for model_name in get_available_models(): for model_name in get_available_models():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment