Unverified Commit 0acbf663 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.usps (#2538)

parent 3f70e3c4
from PIL import Image from PIL import Image
import os import os
import numpy as np import numpy as np
from typing import Any, Callable, cast, Optional, Tuple
from .utils import download_url from .utils import download_url
from .vision import VisionDataset from .vision import VisionDataset
...@@ -36,8 +37,14 @@ class USPS(VisionDataset): ...@@ -36,8 +37,14 @@ class USPS(VisionDataset):
], ],
} }
def __init__(self, root, train=True, transform=None, target_transform=None, def __init__(
download=False): self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(USPS, self).__init__(root, transform=transform, super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
split = 'train' if train else 'test' split = 'train' if train else 'test'
...@@ -52,13 +59,13 @@ class USPS(VisionDataset): ...@@ -52,13 +59,13 @@ class USPS(VisionDataset):
raw_data = [line.decode().split() for line in fp.readlines()] raw_data = [line.decode().split() for line in fp.readlines()]
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16)) imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data] targets = [int(d[0]) - 1 for d in raw_data]
self.data = imgs self.data = imgs
self.targets = targets self.targets = targets
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -80,5 +87,5 @@ class USPS(VisionDataset): ...@@ -80,5 +87,5 @@ class USPS(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.data) return len(self.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment