Unverified Commit 9c799349 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Simplify the setup for AnchorGenerator in unittest (#3023)

parent 8c281757
from collections import OrderedDict
import torch
import unittest
from common_utils import TestCase
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList
class Tester(unittest.TestCase):
class Tester(TestCase):
def test_incorrect_anchors(self):
incorrect_sizes = ((2, 4, 8), (32, 8), )
incorrect_aspects = (0.5, 1.0)
......@@ -16,26 +16,20 @@ class Tester(unittest.TestCase):
self.assertRaises(ValueError, anc, image_list, feature_maps)
def _init_test_anchor_generator(self):
anchor_sizes = tuple((x,) for x in [32, 64, 128])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_sizes = ((10,),)
aspect_ratios = ((1,),)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
return anchor_generator
def get_features(self, images):
s0, s1 = images.shape[-2:]
features = [
('0', torch.rand(2, 8, s0 // 4, s1 // 4)),
('1', torch.rand(2, 16, s0 // 8, s1 // 8)),
('2', torch.rand(2, 32, s0 // 16, s1 // 16)),
]
features = OrderedDict(features)
features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
return features
def test_anchor_generator(self):
images = torch.randn(2, 3, 16, 32)
images = torch.randn(2, 3, 15, 15)
features = self.get_features(images)
features = list(features.values())
image_shapes = [i.shape[-2:] for i in images]
images = ImageList(images, image_shapes)
......@@ -43,13 +37,25 @@ class Tester(unittest.TestCase):
model.eval()
anchors = model(images, features)
# Compute target anchors numbers
# Estimate the number of target anchors
grid_sizes = [f.shape[-2:] for f in features]
num_anchors_estimated = 0
for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
self.assertEqual(num_anchors_estimated, 126)
anchors_output = torch.tensor([[-5., -5., 5., 5.],
[0., -5., 10., 5.],
[5., -5., 15., 5.],
[-5., 0., 5., 10.],
[0., 0., 10., 10.],
[5., 0., 15., 10.],
[-5., 5., 5., 15.],
[0., 5., 10., 15.],
[5., 5., 15., 15.]])
self.assertEqual(num_anchors_estimated, 9)
self.assertEqual(len(anchors), 2)
self.assertEqual(tuple(anchors[0].shape), (num_anchors_estimated, 4))
self.assertEqual(tuple(anchors[1].shape), (num_anchors_estimated, 4))
self.assertEqual(tuple(anchors[0].shape), (9, 4))
self.assertEqual(tuple(anchors[1].shape), (9, 4))
self.assertEqual(anchors[0], anchors_output)
self.assertEqual(anchors[1], anchors_output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment