Unverified Commit bf584072 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.sbd (#2535)

parent 49ec4a16
import os
import shutil
from .vision import VisionDataset
from typing import Any, Callable, Optional, Tuple
import numpy as np
......@@ -49,12 +50,14 @@ class SBDataset(VisionDataset):
voc_split_filename = "train_noval.txt"
voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
def __init__(self,
root,
image_set='train',
mode='boundaries',
download=False,
transforms=None):
def __init__(
self,
root: str,
image_set: str = "train",
mode: str = "boundaries",
download: bool = False,
transforms: Optional[Callable] = None,
) -> None:
try:
from scipy.io import loadmat
......@@ -88,8 +91,8 @@ class SBDataset(VisionDataset):
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
with open(os.path.join(split_f), "r") as fh:
file_names = [x.strip() for x in fh.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
......@@ -98,16 +101,16 @@ class SBDataset(VisionDataset):
self._get_target = self._get_segmentation_target \
if self.mode == "segmentation" else self._get_boundaries_target
def _get_segmentation_target(self, filepath):
def _get_segmentation_target(self, filepath: str) -> Image.Image:
mat = self._loadmat(filepath)
return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
def _get_boundaries_target(self, filepath):
def _get_boundaries_target(self, filepath: str) -> np.ndarray:
mat = self._loadmat(filepath)
return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
for i in range(self.num_classes)], axis=0)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = Image.open(self.images[index]).convert('RGB')
target = self._get_target(self.masks[index])
......@@ -116,9 +119,9 @@ class SBDataset(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.images)
def extra_repr(self):
def extra_repr(self) -> str:
lines = ["Image set: {image_set}", "Mode: {mode}"]
return '\n'.join(lines).format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment