Unverified Commit ccb7f45a authored by Philip Meier, committed by GitHub

Add tests for VOC(Segmentation|Detection) and fix existing bugs (#3415)



* use common download utils in VOC and SBDataset

* add tests for VOC

* use common base class for VOC datasets

* remove old voc test and fake data generation
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 7b7cfdd4
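
For orientation, a minimal usage sketch of the refactored datasets, based on the voc.py changes in the diff below (the "data/voc" root is a placeholder; it assumes the archives are already on disk or that download=True is passed):

from torchvision import datasets

# Both classes now share the _VOCBase constructor; only the split directory,
# target directory and target file extension differ between them.
seg = datasets.VOCSegmentation("data/voc", year="2012", image_set="train", download=True)
image, mask = seg[0]  # (PIL.Image.Image, PIL.Image.Image)

# The VOC2007 test split is selected with year="2007" and image_set="test";
# passing year="2007-test" still works but now emits a deprecation warning.
det = datasets.VOCDetection("data/voc", year="2007", image_set="test", download=True)
image, target = det[0]  # target is a dict built from the annotation XML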
...@@ -369,19 +369,6 @@ def svhn_root(): ...@@ -369,19 +369,6 @@ def svhn_root():
yield root yield root
@contextlib.contextmanager
def voc_root():
with get_tmp_dir() as tmp_dir:
voc_dir = os.path.join(tmp_dir, 'VOCdevkit',
'VOC2012', 'ImageSets', 'Main')
os.makedirs(voc_dir)
train_file = os.path.join(voc_dir, 'train.txt')
with open(train_file, 'w') as f:
f.write('test')
yield tmp_dir
@contextlib.contextmanager @contextlib.contextmanager
def ucf101_root(): def ucf101_root():
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
......
@@ -11,7 +11,7 @@ import torchvision
 from torchvision.datasets import utils
 from common_utils import get_tmp_dir
 from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
+    cityscapes_root, svhn_root, ucf101_root, places365_root, widerface_root, stl10_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -20,6 +20,7 @@ import pathlib
 import pickle
 from torchvision import datasets
 import torch
+import shutil

 try:
@@ -259,38 +260,6 @@ class Tester(DatasetTestcase):
         dataset = torchvision.datasets.SVHN(root, split="extra")
         self.generic_classification_dataset_test(dataset, num_images=2)

-    @mock.patch('torchvision.datasets.voc.download_extract')
-    def test_voc_parse_xml(self, mock_download_extract):
-        with voc_root() as root:
-            dataset = torchvision.datasets.VOCDetection(root)
-
-        single_object_xml = """<annotation>
-          <object>
-            <name>cat</name>
-          </object>
-        </annotation>"""
-        multiple_object_xml = """<annotation>
-          <object>
-            <name>cat</name>
-          </object>
-          <object>
-            <name>dog</name>
-          </object>
-        </annotation>"""
-
-        single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
-        multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
-
-        self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
-        self.assertEqual(multiple_object_parsed,
-                         {'annotation': {
-                             'object': [{
-                                 'name': 'cat'
-                             }, {
-                                 'name': 'dog'
-                             }]
-                         }})
-
     @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
     def test_ucf101(self):
         cached_meta_data = None
@@ -756,5 +725,119 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
         self.assertEqual(tuple(dataset.attr_names), info["attr_names"])


+class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.VOCSegmentation
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
+
+    CONFIGS = (
+        *datasets_utils.combinations_grid(
+            year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
+        ),
+        dict(year="2007", image_set="test"),
+        dict(year="2007-test", image_set="test"),
+    )
+
+    def inject_fake_data(self, tmpdir, config):
+        year, is_test_set = (
+            ("2007", True)
+            if config["year"] == "2007-test" or config["image_set"] == "test"
+            else (config["year"], False)
+        )
+        image_set = config["image_set"]
+
+        base_dir = pathlib.Path(tmpdir)
+        if year == "2011":
+            base_dir /= "TrainVal"
+        base_dir = base_dir / "VOCdevkit" / f"VOC{year}"
+        os.makedirs(base_dir)
+
+        num_images, num_images_per_image_set = self._create_image_set_files(base_dir, "ImageSets", is_test_set)
+        datasets_utils.create_image_folder(base_dir, "JPEGImages", lambda idx: f"{idx:06d}.jpg", num_images)
+
+        datasets_utils.create_image_folder(base_dir, "SegmentationClass", lambda idx: f"{idx:06d}.png", num_images)
+        annotation = self._create_annotation_files(base_dir, "Annotations", num_images)
+
+        return dict(num_examples=num_images_per_image_set[image_set], annotation=annotation)
+
+    def _create_image_set_files(self, root, name, is_test_set):
+        root = pathlib.Path(root) / name
+        src = pathlib.Path(root) / "Main"
+        os.makedirs(src, exist_ok=True)
+
+        idcs = dict(train=(0, 1, 2), val=(3, 4), test=(5,))
+        idcs["trainval"] = (*idcs["train"], *idcs["val"])
+
+        for image_set in ("test",) if is_test_set else ("train", "val", "trainval"):
+            self._create_image_set_file(src, image_set, idcs[image_set])
+
+        shutil.copytree(src, root / "Segmentation")
+
+        num_images = max(itertools.chain(*idcs.values())) + 1
+        num_images_per_image_set = dict([(image_set, len(idcs_)) for image_set, idcs_ in idcs.items()])
+        return num_images, num_images_per_image_set
+
+    def _create_image_set_file(self, root, image_set, idcs):
+        with open(pathlib.Path(root) / f"{image_set}.txt", "w") as fh:
+            fh.writelines([f"{idx:06d}\n" for idx in idcs])
+
+    def _create_annotation_files(self, root, name, num_images):
+        root = pathlib.Path(root) / name
+        os.makedirs(root)
+
+        for idx in range(num_images):
+            annotation = self._create_annotation_file(root, f"{idx:06d}.xml")
+
+        return annotation
+
+    def _create_annotation_file(self, root, name):
+        def add_child(parent, name, text=None):
+            child = ET.SubElement(parent, name)
+            child.text = text
+            return child
+
+        def add_name(obj, name="dog"):
+            add_child(obj, "name", name)
+            return name
+
+        def add_bndbox(obj, bndbox=None):
+            if bndbox is None:
+                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
+
+            obj = add_child(obj, "bndbox")
+            for name, text in bndbox.items():
+                add_child(obj, name, text)
+
+            return bndbox
+
+        annotation = ET.Element("annotation")
+        obj = add_child(annotation, "object")
+        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
+
+        with open(pathlib.Path(root) / name, "wb") as fh:
+            fh.write(ET.tostring(annotation))
+
+        return data
+
+
+class VOCDetectionTestCase(VOCSegmentationTestCase):
+    DATASET_CLASS = datasets.VOCDetection
+    FEATURE_TYPES = (PIL.Image.Image, dict)
+
+    def test_annotations(self):
+        with self.create_dataset() as (dataset, info):
+            _, target = dataset[0]
+
+            self.assertIn("annotation", target)
+            annotation = target["annotation"]
+
+            self.assertIn("object", annotation)
+            objects = annotation["object"]
+            self.assertEqual(len(objects), 1)
+
+            object = objects[0]
+            self.assertEqual(object, info["annotation"])
+
+
 if __name__ == "__main__":
     unittest.main()
...@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple ...@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .utils import download_url, verify_str_arg from .utils import download_url, verify_str_arg, download_and_extract_archive
from .voc import download_extract
class SBDataset(VisionDataset): class SBDataset(VisionDataset):
...@@ -77,7 +76,7 @@ class SBDataset(VisionDataset): ...@@ -77,7 +76,7 @@ class SBDataset(VisionDataset):
mask_dir = os.path.join(sbd_root, 'cls') mask_dir = os.path.join(sbd_root, 'cls')
if download: if download:
download_extract(self.url, self.root, self.filename, self.md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset") extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
for f in ["cls", "img", "inst", "train.txt", "val.txt"]: for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
old_path = os.path.join(extracted_ds_root, f) old_path = os.path.join(extracted_ds_root, f)
......
...@@ -4,8 +4,9 @@ import collections ...@@ -4,8 +4,9 @@ import collections
from .vision import VisionDataset from .vision import VisionDataset
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from PIL import Image from PIL import Image
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple, List
from .utils import download_url, verify_str_arg from .utils import download_and_extract_archive, verify_str_arg
import warnings
DATASET_YEAR_DICT = { DATASET_YEAR_DICT = {
'2012': { '2012': {
...@@ -53,23 +54,10 @@ DATASET_YEAR_DICT = { ...@@ -53,23 +54,10 @@ DATASET_YEAR_DICT = {
} }
class VOCSegmentation(VisionDataset): class _VOCBase(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. _SPLITS_DIR: str
_TARGET_DIR: str
Args: _TARGET_FILE_EXT: str
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
def __init__( def __init__(
self, self,
...@@ -81,39 +69,86 @@ class VOCSegmentation(VisionDataset): ...@@ -81,39 +69,86 @@ class VOCSegmentation(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
): ):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform) super().__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
year = "2007-test"
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
valid_sets = ["train", "trainval", "val"]
if year == "2007-test": if year == "2007-test":
valid_sets.append("test") if image_set == "test":
self.image_set = verify_str_arg(image_set, "image_set", valid_sets) warnings.warn(
base_dir = DATASET_YEAR_DICT[year]['base_dir'] "Acessing the test image set of the year 2007 with year='2007-test' is deprecated. "
"Please use the combination year='2007' and image_set='test' instead."
)
year = "2007"
else:
raise ValueError(
"In the test image set of the year 2007 only image_set='test' is allowed. "
"For all other image sets use year='2007' instead."
)
self.year = year
valid_image_sets = ["train", "trainval", "val"]
if year == "2007":
valid_image_sets.append("test")
key = "2007-test"
else:
key = year
self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
dataset_year_dict = DATASET_YEAR_DICT[key]
self.url = dataset_year_dict["url"]
self.filename = dataset_year_dict["filename"]
self.md5 = dataset_year_dict["md5"]
base_dir = dataset_year_dict["base_dir"]
voc_root = os.path.join(self.root, base_dir) voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClass')
if download: if download:
download_extract(self.url, self.root, self.filename, self.md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
if not os.path.isdir(voc_root): if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
' You can use download=True to download it')
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f: with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()] file_names = [x.strip() for x in f.readlines()]
image_dir = os.path.join(voc_root, "JPEGImages")
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks)) target_dir = os.path.join(voc_root, self._TARGET_DIR)
self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
assert len(self.images) == len(self.targets)
def __len__(self) -> int:
return len(self.images)
class VOCSegmentation(_VOCBase):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
``year=="2007"``, can also be ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
_SPLITS_DIR = "Segmentation"
_TARGET_DIR = "SegmentationClass"
_TARGET_FILE_EXT = ".png"
@property
def masks(self) -> List[str]:
return self.targets
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
...@@ -123,7 +158,7 @@ class VOCSegmentation(VisionDataset): ...@@ -123,7 +158,7 @@ class VOCSegmentation(VisionDataset):
Returns: Returns:
tuple: (image, target) where target is the image segmentation. tuple: (image, target) where target is the image segmentation.
""" """
img = Image.open(self.images[index]).convert('RGB') img = Image.open(self.images[index]).convert("RGB")
target = Image.open(self.masks[index]) target = Image.open(self.masks[index])
if self.transforms is not None: if self.transforms is not None:
...@@ -131,17 +166,15 @@ class VOCSegmentation(VisionDataset): ...@@ -131,17 +166,15 @@ class VOCSegmentation(VisionDataset):
return img, target return img, target
def __len__(self) -> int:
return len(self.images)
class VOCDetection(VisionDataset): class VOCDetection(_VOCBase):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
Args: Args:
root (string): Root directory of the VOC Dataset. root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012. year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
``year=="2007"``, can also be ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not puts it in root directory. If dataset is already downloaded, it is not
downloaded again. downloaded again.
...@@ -154,50 +187,13 @@ class VOCDetection(VisionDataset): ...@@ -154,50 +187,13 @@ class VOCDetection(VisionDataset):
and returns a transformed version. and returns a transformed version.
""" """
def __init__( _SPLITS_DIR = "Main"
self, _TARGET_DIR = "Annotations"
root: str, _TARGET_FILE_EXT = ".xml"
year: str = "2012",
image_set: str = "train",
download: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
year = "2007-test"
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
valid_sets = ["train", "trainval", "val"]
if year == "2007-test":
valid_sets.append("test")
self.image_set = verify_str_arg(image_set, "image_set", valid_sets)
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
annotation_dir = os.path.join(voc_root, 'Annotations')
if download:
download_extract(self.url, self.root, self.filename, self.md5)
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
splits_dir = os.path.join(voc_root, 'ImageSets/Main')
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') @property
def annotations(self) -> List[str]:
with open(os.path.join(split_f), "r") as f: return self.targets
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations))
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
...@@ -207,18 +203,14 @@ class VOCDetection(VisionDataset): ...@@ -207,18 +203,14 @@ class VOCDetection(VisionDataset):
Returns: Returns:
tuple: (image, target) where target is a dictionary of the XML tree. tuple: (image, target) where target is a dictionary of the XML tree.
""" """
img = Image.open(self.images[index]).convert('RGB') img = Image.open(self.images[index]).convert("RGB")
target = self.parse_voc_xml( target = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
ET.parse(self.annotations[index]).getroot())
if self.transforms is not None: if self.transforms is not None:
img, target = self.transforms(img, target) img, target = self.transforms(img, target)
return img, target return img, target
def __len__(self) -> int:
return len(self.images)
def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]: def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
voc_dict: Dict[str, Any] = {} voc_dict: Dict[str, Any] = {}
children = list(node) children = list(node)
...@@ -227,21 +219,11 @@ class VOCDetection(VisionDataset): ...@@ -227,21 +219,11 @@ class VOCDetection(VisionDataset):
for dc in map(self.parse_voc_xml, children): for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items(): for ind, v in dc.items():
def_dic[ind].append(v) def_dic[ind].append(v)
if node.tag == 'annotation': if node.tag == "annotation":
def_dic['object'] = [def_dic['object']] def_dic["object"] = [def_dic["object"]]
voc_dict = { voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
node.tag:
{ind: v[0] if len(v) == 1 else v
for ind, v in def_dic.items()}
}
if node.text: if node.text:
text = node.text.strip() text = node.text.strip()
if not children: if not children:
voc_dict[node.tag] = text voc_dict[node.tag] = text
return voc_dict return voc_dict
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)