"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "c03bba0008146fccfec73c74cffe923299facea8"
Unverified Commit c8345212 authored by João Fernandes's avatar João Fernandes Committed by GitHub
Browse files

Small indentation fix (#1831)

* Force object annotiation to be an array

* Remove unecessary parentheses

* Change object check

* Remove check for list

* Add test coverage to xml parsing

* Tidy up whitespace

* Fix indentation
parent a7914797
...@@ -40,10 +40,12 @@ class Tester(unittest.TestCase): ...@@ -40,10 +40,12 @@ class Tester(unittest.TestCase):
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b']) classes = sorted(['a', 'b'])
class_a_image_files = [os.path.join(root, 'a', file) class_a_image_files = [
for file in ('a1.png', 'a2.png', 'a3.png')] os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png')
class_b_image_files = [os.path.join(root, 'b', file) ]
for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] class_b_image_files = [
os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
]
dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x) dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)
# test if all classes are present # test if all classes are present
...@@ -66,8 +68,8 @@ class Tester(unittest.TestCase): ...@@ -66,8 +68,8 @@ class Tester(unittest.TestCase):
self.assertEqual(imgs, outputs) self.assertEqual(imgs, outputs)
# redo all tests with specified valid image files # redo all tests with specified valid image files
dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x, dataset = torchvision.datasets.ImageFolder(
is_valid_file=lambda x: '3' in x) root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
self.assertEqual(classes, sorted(dataset.classes)) self.assertEqual(classes, sorted(dataset.classes))
class_a_idx = dataset.class_to_idx['a'] class_a_idx = dataset.class_to_idx['a']
...@@ -164,18 +166,18 @@ class Tester(unittest.TestCase): ...@@ -164,18 +166,18 @@ class Tester(unittest.TestCase):
for split in splits: for split in splits:
for target_type in ['semantic', 'instance']: for target_type in ['semantic', 'instance']:
dataset = torchvision.datasets.Cityscapes(root, split=split, dataset = torchvision.datasets.Cityscapes(
target_type=target_type, mode=mode) root, split=split, target_type=target_type, mode=mode)
self.generic_segmentation_dataset_test(dataset, num_images=2) self.generic_segmentation_dataset_test(dataset, num_images=2)
color_dataset = torchvision.datasets.Cityscapes(root, split=split, color_dataset = torchvision.datasets.Cityscapes(
target_type='color', mode=mode) root, split=split, target_type='color', mode=mode)
color_img, color_target = color_dataset[0] color_img, color_target = color_dataset[0]
self.assertTrue(isinstance(color_img, PIL.Image.Image)) self.assertTrue(isinstance(color_img, PIL.Image.Image))
self.assertTrue(np.array(color_target).shape[2] == 4) self.assertTrue(np.array(color_target).shape[2] == 4)
polygon_dataset = torchvision.datasets.Cityscapes(root, split=split, polygon_dataset = torchvision.datasets.Cityscapes(
target_type='polygon', mode=mode) root, split=split, target_type='polygon', mode=mode)
polygon_img, polygon_target = polygon_dataset[0] polygon_img, polygon_target = polygon_dataset[0]
self.assertTrue(isinstance(polygon_img, PIL.Image.Image)) self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
self.assertTrue(isinstance(polygon_target, dict)) self.assertTrue(isinstance(polygon_target, dict))
...@@ -184,9 +186,8 @@ class Tester(unittest.TestCase): ...@@ -184,9 +186,8 @@ class Tester(unittest.TestCase):
# Test multiple target types # Test multiple target types
targets_combo = ['semantic', 'polygon', 'color'] targets_combo = ['semantic', 'polygon', 'color']
multiple_types_dataset = torchvision.datasets.Cityscapes(root, split=split, multiple_types_dataset = torchvision.datasets.Cityscapes(
target_type=targets_combo, root, split=split, target_type=targets_combo, mode=mode)
mode=mode)
output = multiple_types_dataset[0] output = multiple_types_dataset[0]
self.assertTrue(isinstance(output, tuple)) self.assertTrue(isinstance(output, tuple))
self.assertTrue(len(output) == 2) self.assertTrue(len(output) == 2)
...@@ -229,13 +230,19 @@ class Tester(unittest.TestCase): ...@@ -229,13 +230,19 @@ class Tester(unittest.TestCase):
<name>dog</name> <name>dog</name>
</object> </object>
</annotation>""" </annotation>"""
single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml
)) single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml)) multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
self.assertEqual(single_object_parsed, {'annotation': {'object':[{'name': 'cat'}]}}) self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
self.assertEqual(multiple_object_parsed, {'annotation': self.assertEqual(multiple_object_parsed,
{'object':[{'name': 'cat'}, {'name': 'dog'}]}}) {'annotation': {
'object': [{
'name': 'cat'
}, {
'name': 'dog'
}]
}})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment