Unverified Commit ccb7f45a authored by Philip Meier, committed by GitHub

Add tests for VOC(Segmentation|Detection) and fix existing bugs (#3415)



* use common download utils in VOC and SBDataset

* add tests for VOC

* use common base class for VOC datasets

* remove old voc test and fake data generation
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 7b7cfdd4
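At a glance, the refactoring in this commit moves the duplicated file-discovery logic of `VOCSegmentation` and `VOCDetection` into a shared `_VOCBase`, with each subclass only declaring where its split lists and targets live. The sketch below is illustrative only (the class names ending in `Sketch` are hypothetical and the download/validation steps are omitted); the real implementation is in the `torchvision/datasets/voc.py` hunks further down.

```python
import os


class _VOCBaseSketch:
    # Subclasses fill in these three class attributes; everything else is shared.
    _SPLITS_DIR: str       # subdirectory of ImageSets/ holding the split .txt files
    _TARGET_DIR: str       # directory holding the per-image targets
    _TARGET_FILE_EXT: str  # file extension of those targets

    def __init__(self, voc_root: str, image_set: str = "train"):
        split_f = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR, f"{image_set}.txt")
        with open(split_f) as f:
            file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(voc_root, "JPEGImages", x + ".jpg") for x in file_names]
        target_dir = os.path.join(voc_root, self._TARGET_DIR)
        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]


class SegmentationSketch(_VOCBaseSketch):
    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"


class DetectionSketch(_VOCBaseSketch):
    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"
```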
@@ -369,19 +369,6 @@ def svhn_root():
yield root
@contextlib.contextmanager
def voc_root():
with get_tmp_dir() as tmp_dir:
voc_dir = os.path.join(tmp_dir, 'VOCdevkit',
'VOC2012', 'ImageSets', 'Main')
os.makedirs(voc_dir)
train_file = os.path.join(voc_dir, 'train.txt')
with open(train_file, 'w') as f:
f.write('test')
yield tmp_dir
@contextlib.contextmanager
def ucf101_root():
with get_tmp_dir() as tmp_dir:
@@ -11,7 +11,7 @@ import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
cityscapes_root, svhn_root, ucf101_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -20,6 +20,7 @@ import pathlib
import pickle
from torchvision import datasets
import torch
import shutil
try:
@@ -259,38 +260,6 @@ class Tester(DatasetTestcase):
dataset = torchvision.datasets.SVHN(root, split="extra")
self.generic_classification_dataset_test(dataset, num_images=2)
@mock.patch('torchvision.datasets.voc.download_extract')
def test_voc_parse_xml(self, mock_download_extract):
with voc_root() as root:
dataset = torchvision.datasets.VOCDetection(root)
single_object_xml = """<annotation>
<object>
<name>cat</name>
</object>
</annotation>"""
multiple_object_xml = """<annotation>
<object>
<name>cat</name>
</object>
<object>
<name>dog</name>
</object>
</annotation>"""
single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
self.assertEqual(multiple_object_parsed,
{'annotation': {
'object': [{
'name': 'cat'
}, {
'name': 'dog'
}]
}})
@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
cached_meta_data = None
@@ -756,5 +725,119 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
CONFIGS = (
*datasets_utils.combinations_grid(
year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
),
dict(year="2007", image_set="test"),
dict(year="2007-test", image_set="test"),
)
def inject_fake_data(self, tmpdir, config):
year, is_test_set = (
("2007", True)
if config["year"] == "2007-test" or config["image_set"] == "test"
else (config["year"], False)
)
image_set = config["image_set"]
base_dir = pathlib.Path(tmpdir)
if year == "2011":
base_dir /= "TrainVal"
base_dir = base_dir / "VOCdevkit" / f"VOC{year}"
os.makedirs(base_dir)
num_images, num_images_per_image_set = self._create_image_set_files(base_dir, "ImageSets", is_test_set)
datasets_utils.create_image_folder(base_dir, "JPEGImages", lambda idx: f"{idx:06d}.jpg", num_images)
datasets_utils.create_image_folder(base_dir, "SegmentationClass", lambda idx: f"{idx:06d}.png", num_images)
annotation = self._create_annotation_files(base_dir, "Annotations", num_images)
return dict(num_examples=num_images_per_image_set[image_set], annotation=annotation)
def _create_image_set_files(self, root, name, is_test_set):
root = pathlib.Path(root) / name
src = pathlib.Path(root) / "Main"
os.makedirs(src, exist_ok=True)
idcs = dict(train=(0, 1, 2), val=(3, 4), test=(5,))
idcs["trainval"] = (*idcs["train"], *idcs["val"])
for image_set in ("test",) if is_test_set else ("train", "val", "trainval"):
self._create_image_set_file(src, image_set, idcs[image_set])
shutil.copytree(src, root / "Segmentation")
num_images = max(itertools.chain(*idcs.values())) + 1
num_images_per_image_set = dict([(image_set, len(idcs_)) for image_set, idcs_ in idcs.items()])
return num_images, num_images_per_image_set
def _create_image_set_file(self, root, image_set, idcs):
with open(pathlib.Path(root) / f"{image_set}.txt", "w") as fh:
fh.writelines([f"{idx:06d}\n" for idx in idcs])
def _create_annotation_files(self, root, name, num_images):
root = pathlib.Path(root) / name
os.makedirs(root)
for idx in range(num_images):
annotation = self._create_annotation_file(root, f"{idx:06d}.xml")
return annotation
def _create_annotation_file(self, root, name):
def add_child(parent, name, text=None):
child = ET.SubElement(parent, name)
child.text = text
return child
def add_name(obj, name="dog"):
add_child(obj, "name", name)
return name
def add_bndbox(obj, bndbox=None):
if bndbox is None:
bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
obj = add_child(obj, "bndbox")
for name, text in bndbox.items():
add_child(obj, name, text)
return bndbox
annotation = ET.Element("annotation")
obj = add_child(annotation, "object")
data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
with open(pathlib.Path(root) / name, "wb") as fh:
fh.write(ET.tostring(annotation))
return data
class VOCDetectionTestCase(VOCSegmentationTestCase):
DATASET_CLASS = datasets.VOCDetection
FEATURE_TYPES = (PIL.Image.Image, dict)
def test_annotations(self):
with self.create_dataset() as (dataset, info):
_, target = dataset[0]
self.assertIn("annotation", target)
annotation = target["annotation"]
self.assertIn("object", annotation)
objects = annotation["object"]
self.assertEqual(len(objects), 1)
object = objects[0]
self.assertEqual(object, info["annotation"])
if __name__ == "__main__":
unittest.main()
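The `CONFIGS` grid used by `VOCSegmentationTestCase` above comes from `datasets_utils.combinations_grid`. Judging from its usage, it yields one config dict per combination of the keyword-argument values; a minimal re-implementation under that assumption (the name `combinations_grid_sketch` is hypothetical, the real helper lives in `test/datasets_utils.py`):

```python
import itertools


def combinations_grid_sketch(**kwargs):
    # One dict per element of the cartesian product of the keyword-argument values.
    keys = list(kwargs)
    return [dict(zip(keys, values)) for values in itertools.product(*kwargs.values())]


configs = combinations_grid_sketch(
    year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
)
assert len(configs) == 6 * 3
assert dict(year="2012", image_set="trainval") in configs
```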
@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url, verify_str_arg
from .voc import download_extract
from .utils import download_url, verify_str_arg, download_and_extract_archive
class SBDataset(VisionDataset):
@@ -77,7 +76,7 @@ class SBDataset(VisionDataset):
mask_dir = os.path.join(sbd_root, 'cls')
if download:
download_extract(self.url, self.root, self.filename, self.md5)
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
old_path = os.path.join(extracted_ds_root, f)
@@ -4,8 +4,9 @@ import collections
from .vision import VisionDataset
import xml.etree.ElementTree as ET
from PIL import Image
from typing import Any, Callable, Dict, Optional, Tuple
from .utils import download_url, verify_str_arg
from typing import Any, Callable, Dict, Optional, Tuple, List
from .utils import download_and_extract_archive, verify_str_arg
import warnings
DATASET_YEAR_DICT = {
'2012': {
@@ -53,23 +54,10 @@ DATASET_YEAR_DICT = {
}
class VOCSegmentation(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
class _VOCBase(VisionDataset):
_SPLITS_DIR: str
_TARGET_DIR: str
_TARGET_FILE_EXT: str
def __init__(
self,
@@ -81,39 +69,86 @@ class VOCSegmentation(VisionDataset):
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
year = "2007-test"
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
valid_sets = ["train", "trainval", "val"]
super().__init__(root, transforms, transform, target_transform)
if year == "2007-test":
valid_sets.append("test")
self.image_set = verify_str_arg(image_set, "image_set", valid_sets)
base_dir = DATASET_YEAR_DICT[year]['base_dir']
if image_set == "test":
warnings.warn(
"Acessing the test image set of the year 2007 with year='2007-test' is deprecated. "
"Please use the combination year='2007' and image_set='test' instead."
)
year = "2007"
else:
raise ValueError(
"In the test image set of the year 2007 only image_set='test' is allowed. "
"For all other image sets use year='2007' instead."
)
self.year = year
valid_image_sets = ["train", "trainval", "val"]
if year == "2007":
valid_image_sets.append("test")
key = "2007-test"
else:
key = year
self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
dataset_year_dict = DATASET_YEAR_DICT[key]
self.url = dataset_year_dict["url"]
self.filename = dataset_year_dict["filename"]
self.md5 = dataset_year_dict["md5"]
base_dir = dataset_year_dict["base_dir"]
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClass')
if download:
download_extract(self.url, self.root, self.filename, self.md5)
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
image_dir = os.path.join(voc_root, "JPEGImages")
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
target_dir = os.path.join(voc_root, self._TARGET_DIR)
self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
assert len(self.images) == len(self.targets)
def __len__(self) -> int:
return len(self.images)
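The deprecation handling above can be hard to follow in diff form. The following standalone sketch (the helper `resolve_voc_key` is hypothetical, not part of torchvision) mirrors how `_VOCBase` appears to normalize the deprecated `year='2007-test'` spelling and pick the `DATASET_YEAR_DICT` entry:

```python
import warnings


def resolve_voc_key(year: str, image_set: str) -> str:
    if year == "2007-test":
        if image_set == "test":
            warnings.warn(
                "Accessing the test image set of the year 2007 with year='2007-test' is deprecated. "
                "Please use the combination year='2007' and image_set='test' instead."
            )
            year = "2007"
        else:
            raise ValueError(
                "In the test image set of the year 2007 only image_set='test' is allowed. "
                "For all other image sets use year='2007' instead."
            )

    valid_image_sets = ["train", "trainval", "val"]
    if year == "2007":
        # the 2007 test split ships as a separate archive, keyed "2007-test"
        valid_image_sets.append("test")
    if image_set not in valid_image_sets:
        raise ValueError(f"Unknown image_set {image_set!r} for year {year!r}")

    return "2007-test" if image_set == "test" else year


assert resolve_voc_key("2012", "train") == "2012"
assert resolve_voc_key("2007", "test") == "2007-test"
```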
class VOCSegmentation(_VOCBase):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
``year=="2007"``, can also be ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
_SPLITS_DIR = "Segmentation"
_TARGET_DIR = "SegmentationClass"
_TARGET_FILE_EXT = ".png"
@property
def masks(self) -> List[str]:
return self.targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
@@ -123,7 +158,7 @@ class VOCSegmentation(VisionDataset):
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('RGB')
img = Image.open(self.images[index]).convert("RGB")
target = Image.open(self.masks[index])
if self.transforms is not None:
@@ -131,17 +166,15 @@ class VOCSegmentation(VisionDataset):
return img, target
def __len__(self) -> int:
return len(self.images)
class VOCDetection(VisionDataset):
class VOCDetection(_VOCBase):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
``year=="2007"``, can also be ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
@@ -154,50 +187,13 @@ class VOCDetection(VisionDataset):
and returns a transformed version.
"""
def __init__(
self,
root: str,
year: str = "2012",
image_set: str = "train",
download: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
year = "2007-test"
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
valid_sets = ["train", "trainval", "val"]
if year == "2007-test":
valid_sets.append("test")
self.image_set = verify_str_arg(image_set, "image_set", valid_sets)
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
annotation_dir = os.path.join(voc_root, 'Annotations')
if download:
download_extract(self.url, self.root, self.filename, self.md5)
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
splits_dir = os.path.join(voc_root, 'ImageSets/Main')
_SPLITS_DIR = "Main"
_TARGET_DIR = "Annotations"
_TARGET_FILE_EXT = ".xml"
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations))
@property
def annotations(self) -> List[str]:
return self.targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
@@ -207,18 +203,14 @@ class VOCDetection(VisionDataset):
Returns:
tuple: (image, target) where target is a dictionary of the XML tree.
"""
img = Image.open(self.images[index]).convert('RGB')
target = self.parse_voc_xml(
ET.parse(self.annotations[index]).getroot())
img = Image.open(self.images[index]).convert("RGB")
target = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self) -> int:
return len(self.images)
def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
voc_dict: Dict[str, Any] = {}
children = list(node)
@@ -227,21 +219,11 @@ class VOCDetection(VisionDataset):
for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items():
def_dic[ind].append(v)
if node.tag == 'annotation':
def_dic['object'] = [def_dic['object']]
voc_dict = {
node.tag:
{ind: v[0] if len(v) == 1 else v
for ind, v in def_dic.items()}
}
if node.tag == "annotation":
def_dic["object"] = [def_dic["object"]]
voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
if node.text:
text = node.text.strip()
if not children:
voc_dict[node.tag] = text
return voc_dict
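Since the old in-tree `test_voc_parse_xml` was removed by this commit, here is a standalone sketch of the parsing behavior it covered. The function `parse_voc_xml_sketch` is a hypothetical re-implementation for illustration; the real method is `VOCDetection.parse_voc_xml` shown above, and the expected output matches the removed test's assertions.

```python
import collections
import xml.etree.ElementTree as ET
from typing import Any, Dict


def parse_voc_xml_sketch(node: ET.Element) -> Dict[str, Any]:
    voc_dict: Dict[str, Any] = {}
    children = list(node)
    if children:
        def_dic: Dict[str, Any] = collections.defaultdict(list)
        for dc in map(parse_voc_xml_sketch, children):
            for ind, v in dc.items():
                def_dic[ind].append(v)
        if node.tag == "annotation":
            # a single <object> is still wrapped in a list
            def_dic["object"] = [def_dic["object"]]
        voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
    if node.text:
        text = node.text.strip()
        if not children:
            voc_dict[node.tag] = text
    return voc_dict


single = "<annotation><object><name>cat</name></object></annotation>"
assert parse_voc_xml_sketch(ET.fromstring(single)) == {"annotation": {"object": [{"name": "cat"}]}}
```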
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)