Unverified Commit cf78a29b authored by João Fernandes's avatar João Fernandes Committed by GitHub
Browse files

Force object annotiation to be a list (#1790)

* Force object annotiation to be an array

* Remove unecessary parentheses

* Change object check

* Remove check for list

* Add test coverage to xml parsing

* Tidy up whitespace
parent 9e8258d1
......@@ -258,3 +258,15 @@ def svhn_root():
_make_mat(os.path.join(root, "extra_32x32.mat"))
yield root
@contextlib.contextmanager
def voc_root():
with get_tmp_dir() as tmp_dir:
voc_dir = os.path.join(tmp_dir, 'VOCdevkit',
'VOC2012','ImageSets','Main')
os.makedirs(voc_dir)
train_file = os.path.join(voc_dir,'train.txt')
with open(train_file, 'w') as f:
f.write('test')
yield tmp_dir
......@@ -9,7 +9,8 @@ from torch._utils_internal import get_file_path_2
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root
cityscapes_root, svhn_root, voc_root
import xml.etree.ElementTree as ET
try:
......@@ -210,6 +211,32 @@ class Tester(unittest.TestCase):
dataset = torchvision.datasets.SVHN(root, split="extra")
self.generic_classification_dataset_test(dataset, num_images=2)
@mock.patch('torchvision.datasets.voc.download_extract')
def test_voc_parse_xml(self, mock_download_extract):
with voc_root() as root:
dataset = torchvision.datasets.VOCDetection(root)
single_object_xml = """<annotation>
<object>
<name>cat</name>
</object>
</annotation>"""
multiple_object_xml = """<annotation>
<object>
<name>cat</name>
</object>
<object>
<name>dog</name>
</object>
</annotation>"""
single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml
))
multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
self.assertEqual(single_object_parsed, {'annotation': {'object':[{'name': 'cat'}]}})
self.assertEqual(multiple_object_parsed, {'annotation':
{'object':[{'name': 'cat'}, {'name': 'dog'}]}})
if __name__ == '__main__':
unittest.main()
......@@ -218,6 +218,8 @@ class VOCDetection(VisionDataset):
for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items():
def_dic[ind].append(v)
if node.tag == 'annotation':
def_dic['object'] = [def_dic['object']]
voc_dict = {
node.tag:
{ind: v[0] if len(v) == 1 else v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment