Unverified Commit ec9c7a54 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.lsun (#2530)

parent 5f2e140a
...@@ -6,11 +6,15 @@ import io ...@@ -6,11 +6,15 @@ import io
import string import string
from collections.abc import Iterable from collections.abc import Iterable
import pickle import pickle
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from .utils import verify_str_arg, iterable_to_str from .utils import verify_str_arg, iterable_to_str
class LSUNClass(VisionDataset): class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(
self, root: str, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None
) -> None:
import lmdb import lmdb
super(LSUNClass, self).__init__(root, transform=transform, super(LSUNClass, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -27,7 +31,7 @@ class LSUNClass(VisionDataset): ...@@ -27,7 +31,7 @@ class LSUNClass(VisionDataset):
self.keys = [key for key, _ in txn.cursor()] self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb")) pickle.dump(self.keys, open(cache_file, "wb"))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = None, None img, target = None, None
env = self.env env = self.env
with env.begin(write=False) as txn: with env.begin(write=False) as txn:
...@@ -46,7 +50,7 @@ class LSUNClass(VisionDataset): ...@@ -46,7 +50,7 @@ class LSUNClass(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return self.length return self.length
...@@ -64,7 +68,13 @@ class LSUN(VisionDataset): ...@@ -64,7 +68,13 @@ class LSUN(VisionDataset):
target and transforms it. target and transforms it.
""" """
def __init__(self, root, classes='train', transform=None, target_transform=None): def __init__(
self,
root: str,
classes: Union[str, List[str]] = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(LSUN, self).__init__(root, transform=transform, super(LSUN, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.classes = self._verify_classes(classes) self.classes = self._verify_classes(classes)
...@@ -84,13 +94,14 @@ class LSUN(VisionDataset): ...@@ -84,13 +94,14 @@ class LSUN(VisionDataset):
self.length = count self.length = count
def _verify_classes(self, classes): def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower'] 'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test'] dset_opts = ['train', 'val', 'test']
try: try:
classes = cast(str, classes)
verify_str_arg(classes, "classes", dset_opts) verify_str_arg(classes, "classes", dset_opts)
if classes == 'test': if classes == 'test':
classes = [classes] classes = [classes]
...@@ -120,7 +131,7 @@ class LSUN(VisionDataset): ...@@ -120,7 +131,7 @@ class LSUN(VisionDataset):
return classes return classes
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -145,8 +156,8 @@ class LSUN(VisionDataset): ...@@ -145,8 +156,8 @@ class LSUN(VisionDataset):
img, _ = db[index] img, _ = db[index]
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return self.length return self.length
def extra_repr(self): def extra_repr(self) -> str:
return "Classes: {classes}".format(**self.__dict__) return "Classes: {classes}".format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment