"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "b721550223bf843d6d651e1de39b3b1afa95f8d2"
Unverified Commit 59833e76 authored by Anirudh's avatar Anirudh Committed by GitHub
Browse files

port test_models_detection_utils.py to pytest (#4036)

parent e27b3925
...@@ -2,12 +2,13 @@ import copy ...@@ -2,12 +2,13 @@ import copy
import torch import torch
from torchvision.models.detection import _utils from torchvision.models.detection import _utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.transform import GeneralizedRCNNTransform
import unittest import pytest
from torchvision.models.detection import backbone_utils from torchvision.models.detection import backbone_utils
from _assert_utils import assert_equal from _assert_utils import assert_equal
class Tester(unittest.TestCase): class TestModelsDetectionUtils:
def test_balanced_positive_negative_sampler(self): def test_balanced_positive_negative_sampler(self):
sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25) sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25)
# keep all 6 negatives first, then add 3 positives, last two are ignore # keep all 6 negatives first, then add 3 positives, last two are ignore
...@@ -16,39 +17,40 @@ class Tester(unittest.TestCase): ...@@ -16,39 +17,40 @@ class Tester(unittest.TestCase):
# we know the number of elements that should be sampled for the positive (1) # we know the number of elements that should be sampled for the positive (1)
# and the negative (3), and their location. Let's make sure that they are # and the negative (3), and their location. Let's make sure that they are
# there # there
self.assertEqual(pos[0].sum(), 1) assert pos[0].sum() == 1
self.assertEqual(pos[0][6:9].sum(), 1) assert pos[0][6:9].sum() == 1
self.assertEqual(neg[0].sum(), 3) assert neg[0].sum() == 3
self.assertEqual(neg[0][0:6].sum(), 3) assert neg[0][0:6].sum() == 3
def test_resnet_fpn_backbone_frozen_layers(self): @pytest.mark.parametrize('train_layers, exp_froz_params', [
(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)
])
def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
# we know how many initial layers and parameters of the network should # we know how many initial layers and parameters of the network should
# be frozen for each trainable_backbone_layers parameter value # be frozen for each trainable_backbone_layers parameter value
# i.e all 53 params are frozen if trainable_backbone_layers=0 # i.e all 53 params are frozen if trainable_backbone_layers=0
# ad first 24 params are frozen if trainable_backbone_layers=2 # ad first 24 params are frozen if trainable_backbone_layers=2
expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0} model = backbone_utils.resnet_fpn_backbone(
for train_layers, exp_froz_params in expected_frozen_params.items(): 'resnet50', pretrained=False, trainable_layers=train_layers)
model = backbone_utils.resnet_fpn_backbone( # boolean list that is true if the param at that index is frozen
'resnet50', pretrained=False, trainable_layers=train_layers) is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
# boolean list that is true if the param at that index is frozen # check that expected initial number of layers are frozen
is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] assert all(is_frozen[:exp_froz_params])
# check that expected initial number of layers are frozen
self.assertTrue(all(is_frozen[:exp_froz_params]))
def test_validate_resnet_inputs_detection(self): def test_validate_resnet_inputs_detection(self):
# default number of backbone layers to train # default number of backbone layers to train
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
self.assertEqual(ret, 3) assert ret == 3
# can't go beyond 5 # can't go beyond 5
with self.assertRaises(AssertionError): with pytest.raises(AssertionError):
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
# if not pretrained, should use all trainable layers and warn # if not pretrained, should use all trainable layers and warn
with self.assertWarns(UserWarning): with pytest.warns(UserWarning):
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
self.assertEqual(ret, 5) assert ret == 5
def test_transform_copy_targets(self): def test_transform_copy_targets(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
...@@ -63,9 +65,9 @@ class Tester(unittest.TestCase): ...@@ -63,9 +65,9 @@ class Tester(unittest.TestCase):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
targets = [{'boxes': torch.rand(3, 4)}] targets = [{'boxes': torch.rand(3, 4)}]
with self.assertRaises(TypeError): with pytest.raises(TypeError):
out = transform(image, targets) # noqa: F841 out = transform(image, targets) # noqa: F841
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment