Unverified Commit bf584072 authored by Philip Meier, committed by GitHub

add typehints for torchvision.datasets.sbd (#2535)

parent 49ec4a16
 import os
 import shutil
 from .vision import VisionDataset
+from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
@@ -49,12 +50,14 @@ class SBDataset(VisionDataset):
     voc_split_filename = "train_noval.txt"
     voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
 
-    def __init__(self,
-                 root,
-                 image_set='train',
-                 mode='boundaries',
-                 download=False,
-                 transforms=None):
+    def __init__(
+            self,
+            root: str,
+            image_set: str = "train",
+            mode: str = "boundaries",
+            download: bool = False,
+            transforms: Optional[Callable] = None,
+    ) -> None:
         try:
             from scipy.io import loadmat
@@ -88,8 +91,8 @@ class SBDataset(VisionDataset):
         split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
 
-        with open(os.path.join(split_f), "r") as f:
-            file_names = [x.strip() for x in f.readlines()]
+        with open(os.path.join(split_f), "r") as fh:
+            file_names = [x.strip() for x in fh.readlines()]
 
         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
         self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
@@ -98,16 +101,16 @@ class SBDataset(VisionDataset):
         self._get_target = self._get_segmentation_target \
             if self.mode == "segmentation" else self._get_boundaries_target
 
-    def _get_segmentation_target(self, filepath):
+    def _get_segmentation_target(self, filepath: str) -> Image.Image:
         mat = self._loadmat(filepath)
         return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
 
-    def _get_boundaries_target(self, filepath):
+    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
         mat = self._loadmat(filepath)
         return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
                                for i in range(self.num_classes)], axis=0)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         img = Image.open(self.images[index]).convert('RGB')
         target = self._get_target(self.masks[index])
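The broad Tuple[Any, Any] return type of __getitem__ is deliberate: the target is a PIL image in "segmentation" mode but a NumPy array in "boundaries" mode, and transforms may change both further. A call site can narrow the types itself; a rough sketch, reusing the hypothetical dataset from the previous example:

import numpy as np
from PIL import Image

img, target = dataset[0]
if isinstance(target, Image.Image):
    # mode="segmentation": a per-pixel class mask as a PIL image
    mask = np.array(target)
elif isinstance(target, np.ndarray):
    # mode="boundaries": one binary boundary map per class, stacked along
    # axis 0 by the concatenation above, i.e. shape (num_classes, H, W)
    num_classes, height, width = target.shape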
@@ -116,9 +119,9 @@ class SBDataset(VisionDataset):
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.images)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         lines = ["Image set: {image_set}", "Mode: {mode}"]
         return '\n'.join(lines).format(**self.__dict__)
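Likewise, the new str annotation on extra_repr matches what the method already returned: two formatted lines built from the instance's image_set and mode attributes. A quick check, again using the hypothetical dataset from above:

print(dataset.extra_repr())
# Expected output (values depend on the constructor arguments):
# Image set: train
# Mode: boundaries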