Unverified Commit 5405739e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

vision (#2526)

parent 3245b10d
import os import os
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from typing import Any, Callable, List, Optional, Tuple
class VisionDataset(data.Dataset): class VisionDataset(data.Dataset):
_repr_indent = 4 _repr_indent = 4
def __init__(self, root, transforms=None, transform=None, target_transform=None): def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
if isinstance(root, torch._six.string_classes): if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root) root = os.path.expanduser(root)
self.root = root self.root = root
...@@ -25,13 +32,13 @@ class VisionDataset(data.Dataset): ...@@ -25,13 +32,13 @@ class VisionDataset(data.Dataset):
transforms = StandardTransform(transform, target_transform) transforms = StandardTransform(transform, target_transform)
self.transforms = transforms self.transforms = transforms
def __getitem__(self, index): def __getitem__(self, index: int) -> Any:
raise NotImplementedError raise NotImplementedError
def __len__(self): def __len__(self) -> int:
raise NotImplementedError raise NotImplementedError
def __repr__(self): def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__ head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())] body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None: if self.root is not None:
...@@ -42,33 +49,33 @@ class VisionDataset(data.Dataset): ...@@ -42,33 +49,33 @@ class VisionDataset(data.Dataset):
lines = [head] + [" " * self._repr_indent + line for line in body] lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines) return '\n'.join(lines)
def _format_transform_repr(self, transform, head): def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines() lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] + return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]]) ["{}{}".format(" " * len(head), line) for line in lines[1:]])
def extra_repr(self): def extra_repr(self) -> str:
return "" return ""
class StandardTransform(object): class StandardTransform(object):
def __init__(self, transform=None, target_transform=None): def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
def __call__(self, input, target): def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None: if self.transform is not None:
input = self.transform(input) input = self.transform(input)
if self.target_transform is not None: if self.target_transform is not None:
target = self.target_transform(target) target = self.target_transform(target)
return input, target return input, target
def _format_transform_repr(self, transform, head): def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines() lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] + return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]]) ["{}{}".format(" " * len(head), line) for line in lines[1:]])
def __repr__(self): def __repr__(self) -> str:
body = [self.__class__.__name__] body = [self.__class__.__name__]
if self.transform is not None: if self.transform is not None:
body += self._format_transform_repr(self.transform, body += self._format_transform_repr(self.transform,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment