Unverified Commit 7c1ed419 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.voc (#2537)

parent 0acbf663
......@@ -4,6 +4,7 @@ import collections
from .vision import VisionDataset
import xml.etree.ElementTree as ET
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .utils import download_url, check_integrity, verify_str_arg
DATASET_YEAR_DICT = {
......@@ -70,14 +71,16 @@ class VOCSegmentation(VisionDataset):
and returns a transformed version.
"""
def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None,
target_transform=None,
transforms=None):
def __init__(
self,
root: str,
year: str = "2012",
image_set: str = "train",
download: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
......@@ -112,7 +115,7 @@ class VOCSegmentation(VisionDataset):
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -128,7 +131,7 @@ class VOCSegmentation(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.images)
......@@ -151,14 +154,16 @@ class VOCDetection(VisionDataset):
and returns a transformed version.
"""
def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None,
target_transform=None,
transforms=None):
def __init__(
self,
root: str,
year: str = "2012",
image_set: str = "train",
download: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
self.year = year
if year == "2007" and image_set == "test":
......@@ -194,7 +199,7 @@ class VOCDetection(VisionDataset):
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -211,14 +216,14 @@ class VOCDetection(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.images)
def parse_voc_xml(self, node):
voc_dict = {}
def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
voc_dict: Dict[str, Any] = {}
children = list(node)
if children:
def_dic = collections.defaultdict(list)
def_dic: Dict[str, Any] = collections.defaultdict(list)
for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items():
def_dic[ind].append(v)
......@@ -236,7 +241,7 @@ class VOCDetection(VisionDataset):
return voc_dict
def download_extract(url, root, filename, md5):
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment