Unverified Commit ec9c7a54 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.lsun (#2530)

parent 5f2e140a
......@@ -6,11 +6,15 @@ import io
import string
from collections.abc import Iterable
import pickle
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from .utils import verify_str_arg, iterable_to_str
class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None):
def __init__(
self, root: str, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None
) -> None:
import lmdb
super(LSUNClass, self).__init__(root, transform=transform,
target_transform=target_transform)
......@@ -27,7 +31,7 @@ class LSUNClass(VisionDataset):
self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb"))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = None, None
env = self.env
with env.begin(write=False) as txn:
......@@ -46,7 +50,7 @@ class LSUNClass(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return self.length
......@@ -64,7 +68,13 @@ class LSUN(VisionDataset):
target and transforms it.
"""
def __init__(self, root, classes='train', transform=None, target_transform=None):
def __init__(
self,
root: str,
classes: Union[str, List[str]] = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(LSUN, self).__init__(root, transform=transform,
target_transform=target_transform)
self.classes = self._verify_classes(classes)
......@@ -84,13 +94,14 @@ class LSUN(VisionDataset):
self.length = count
def _verify_classes(self, classes):
def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']
try:
classes = cast(str, classes)
verify_str_arg(classes, "classes", dset_opts)
if classes == 'test':
classes = [classes]
......@@ -120,7 +131,7 @@ class LSUN(VisionDataset):
return classes
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -145,8 +156,8 @@ class LSUN(VisionDataset):
img, _ = db[index]
return img, target
def __len__(self):
def __len__(self) -> int:
return self.length
def extra_repr(self):
def extra_repr(self) -> str:
return "Classes: {classes}".format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment