"...text-generation-inference.git" did not exist on "cef0553d59713a5e5842b9a1d79334d4ffc066b9"
Unverified Commit 0acbf663 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.usps (#2538)

parent 3f70e3c4
from PIL import Image
import os
import numpy as np
from typing import Any, Callable, cast, Optional, Tuple
from .utils import download_url
from .vision import VisionDataset
......@@ -36,8 +37,14 @@ class USPS(VisionDataset):
],
}
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test'
......@@ -52,13 +59,13 @@ class USPS(VisionDataset):
raw_data = [line.decode().split() for line in fp.readlines()]
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
self.data = imgs
self.targets = targets
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -80,5 +87,5 @@ class USPS(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment