Unverified Commit 203a7841 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints to torchvision.datasets.sbu (#2536)

parent bf584072
from PIL import Image from PIL import Image
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
from typing import Any, Callable, Optional, Tuple
import os import os
from .vision import VisionDataset from .vision import VisionDataset
...@@ -23,7 +24,13 @@ class SBU(VisionDataset): ...@@ -23,7 +24,13 @@ class SBU(VisionDataset):
filename = "SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285' md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None, download=True): def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SBU, self).__init__(root, transform=transform, super(SBU, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -50,7 +57,7 @@ class SBU(VisionDataset): ...@@ -50,7 +57,7 @@ class SBU(VisionDataset):
self.photos.append(photo) self.photos.append(photo)
self.captions.append(caption) self.captions.append(caption)
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -69,11 +76,11 @@ class SBU(VisionDataset): ...@@ -69,11 +76,11 @@ class SBU(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
"""The number of photos in the dataset.""" """The number of photos in the dataset."""
return len(self.photos) return len(self.photos)
def _check_integrity(self): def _check_integrity(self) -> bool:
"""Check the md5 checksum of the downloaded tarball.""" """Check the md5 checksum of the downloaded tarball."""
root = self.root root = self.root
fpath = os.path.join(root, self.filename) fpath = os.path.join(root, self.filename)
...@@ -81,7 +88,7 @@ class SBU(VisionDataset): ...@@ -81,7 +88,7 @@ class SBU(VisionDataset):
return False return False
return True return True
def download(self): def download(self) -> None:
"""Download and extract the tarball, and download each individual photo.""" """Download and extract the tarball, and download each individual photo."""
import tarfile import tarfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment