Unverified Commit 7c1ed419 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.voc (#2537)

parent 0acbf663
...@@ -4,6 +4,7 @@ import collections ...@@ -4,6 +4,7 @@ import collections
from .vision import VisionDataset from .vision import VisionDataset
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from PIL import Image from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .utils import download_url, check_integrity, verify_str_arg from .utils import download_url, check_integrity, verify_str_arg
DATASET_YEAR_DICT = { DATASET_YEAR_DICT = {
...@@ -70,14 +71,16 @@ class VOCSegmentation(VisionDataset): ...@@ -70,14 +71,16 @@ class VOCSegmentation(VisionDataset):
and returns a transformed version. and returns a transformed version.
""" """
def __init__(self, def __init__(
root, self,
year='2012', root: str,
image_set='train', year: str = "2012",
download=False, image_set: str = "train",
transform=None, download: bool = False,
target_transform=None, transform: Optional[Callable] = None,
transforms=None): target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform) super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
self.year = year self.year = year
if year == "2007" and image_set == "test": if year == "2007" and image_set == "test":
...@@ -112,7 +115,7 @@ class VOCSegmentation(VisionDataset): ...@@ -112,7 +115,7 @@ class VOCSegmentation(VisionDataset):
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks)) assert (len(self.images) == len(self.masks))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -128,7 +131,7 @@ class VOCSegmentation(VisionDataset): ...@@ -128,7 +131,7 @@ class VOCSegmentation(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.images) return len(self.images)
...@@ -151,14 +154,16 @@ class VOCDetection(VisionDataset): ...@@ -151,14 +154,16 @@ class VOCDetection(VisionDataset):
and returns a transformed version. and returns a transformed version.
""" """
def __init__(self, def __init__(
root, self,
year='2012', root: str,
image_set='train', year: str = "2012",
download=False, image_set: str = "train",
transform=None, download: bool = False,
target_transform=None, transform: Optional[Callable] = None,
transforms=None): target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super(VOCDetection, self).__init__(root, transforms, transform, target_transform) super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
self.year = year self.year = year
if year == "2007" and image_set == "test": if year == "2007" and image_set == "test":
...@@ -194,7 +199,7 @@ class VOCDetection(VisionDataset): ...@@ -194,7 +199,7 @@ class VOCDetection(VisionDataset):
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names] self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations)) assert (len(self.images) == len(self.annotations))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -211,14 +216,14 @@ class VOCDetection(VisionDataset): ...@@ -211,14 +216,14 @@ class VOCDetection(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.images) return len(self.images)
def parse_voc_xml(self, node): def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
voc_dict = {} voc_dict: Dict[str, Any] = {}
children = list(node) children = list(node)
if children: if children:
def_dic = collections.defaultdict(list) def_dic: Dict[str, Any] = collections.defaultdict(list)
for dc in map(self.parse_voc_xml, children): for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items(): for ind, v in dc.items():
def_dic[ind].append(v) def_dic[ind].append(v)
...@@ -236,7 +241,7 @@ class VOCDetection(VisionDataset): ...@@ -236,7 +241,7 @@ class VOCDetection(VisionDataset):
return voc_dict return voc_dict
def download_extract(url, root, filename, md5): def download_extract(url: str, root: str, filename: str, md5: str) -> None:
download_url(url, root, filename, md5) download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar: with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root) tar.extractall(path=root)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment