"src/client/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "fa0e717d529d925f2d9d833559dd3bd04816d6da"
Unverified Commit 4ba91bff authored by Philip Meier, committed by GitHub

make mypy more strict for prototype datasets (#4513)

* make mypy more strict for prototype datasets

* fix code format

* apply strictness only to datasets

* fix more mypy issues

* cleanup

* fix mnist annotations

* refactor celeba

* warn on redundant casts

* remove redundant cast

* simplify annotation

* fix import
parent 9407b45a
@@ -4,6 +4,23 @@ files = torchvision
show_error_codes = True
pretty = True
allow_redefinition = True
+warn_redundant_casts = True
+
+[mypy-torchvision.prototype.datasets.*]
+; untyped definitions and calls
+disallow_untyped_defs = True
+
+; None and Optional handling
+no_implicit_optional = True
+
+; warnings
+warn_unused_ignores = True
+warn_return_any = True
+warn_unreachable = True
+
+; miscellaneous strictness flags
+allow_redefinition = True
+
[mypy-torchvision.io._video_opt.*]
...
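For reference, a minimal sketch (not part of this commit) of the kind of code the new per-module settings would reject, assuming it lives under torchvision/prototype/datasets/:

from typing import cast

# disallow_untyped_defs: a bare signature is now an error in this package.
def load(path):  # error: Function is missing a type annotation
    ...

# warn_return_any: indexing an untyped dict yields Any, so returning it from a
# function declared to return int is reported.
def num_samples(meta: dict) -> int:
    return meta["num_samples"]  # error: Returning Any from function declared to return "int"

# warn_redundant_casts (enabled globally above): casting to the already-inferred type is flagged.
count = cast(int, 1)  # error: Redundant cast to "int"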
import os
-from typing import Any, Callable, cast, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image

@@ -63,7 +63,7 @@ class USPS(VisionDataset):
        raw_data = [line.decode().split() for line in fp.readlines()]
        tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
        imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
-        imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
+        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
        targets = [int(d[0]) - 1 for d in raw_data]
        self.data = imgs
...
@@ -82,7 +82,10 @@ class Caltech101(Dataset):
        return category, id

    def _collate_and_decode_sample(
-        self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
+        self,
+        data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        key, image_data, ann_data = data
        category, _ = key

@@ -117,11 +120,11 @@ class Caltech101(Dataset):
        images_dp, anns_dp = resource_dps

        images_dp = TarArchiveReader(images_dp)
-        images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
+        images_dp = Filter(images_dp, self._is_not_background_image)
        images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

        anns_dp = TarArchiveReader(anns_dp)
-        anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
+        anns_dp = Filter(anns_dp, self._is_ann)

        dp = KeyZipper(
            images_dp,

@@ -136,7 +139,7 @@ class Caltech101(Dataset):
    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_not_background_image)
+        dp = Filter(dp, self._is_not_background_image)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})

@@ -185,7 +188,7 @@ class Caltech256(Dataset):
    ) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
+        dp = Filter(dp, self._is_not_rogue_file)
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
...
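Dropping the repeated `dp: IterDataPipe = ...` annotations presumably leans on the `allow_redefinition` option enabled in the config above: with it, a name can be re-bound to a value of a different type within the same block without a fresh annotation. A generic illustration (names are made up, not from this diff):

# allow_redefinition = True lets mypy accept re-binding a variable to a new type,
# so intermediate pipeline variables need no explicit re-annotation.
def summarize(raw: str) -> int:
    parts = raw.split(",")              # parts: list of str
    parts = [p.strip() for p in parts]  # still a list of str
    parts = len(parts)                  # re-defined as int; accepted under allow_redefinition
    return parts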
import csv
import io
-from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
import torch
from torchdata.datapipes.iter import (

@@ -23,37 +23,38 @@ from torchvision.prototype.datasets.utils import (
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor

+csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
+
-class CelebACSVParser(IterDataPipe):
+class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
    def __init__(
        self,
-        datapipe,
+        datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
        *,
-        has_header,
-    ):
+        fieldnames: Optional[Sequence[str]] = None,
+    ) -> None:
        self.datapipe = datapipe
-        self.has_header = has_header
-        self._fmtparams = dict(delimiter=" ", skipinitialspace=True)
+        self.fieldnames = fieldnames

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
        for _, file in self.datapipe:
            file = (line.decode() for line in file)
-            if self.has_header:
+            if self.fieldnames:
+                fieldnames = self.fieldnames
+            else:
                # The first row is skipped, because it only contains the number of samples
                next(file)
-                # Empty field names are filtered out, because some files have an extr white space after the header
+                # Empty field names are filtered out, because some files have an extra white space after the header
                # line, which is recognized as extra column
-                fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
+                fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
                # Some files do not include a label for the image ID column
                if fieldnames[0] != "image_id":
                    fieldnames.insert(0, "image_id")

-            for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
+            for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
                yield line.pop("image_id"), line
-            else:
-                for line in csv.reader(file, **self._fmtparams):
-                    yield line[0], line[1:]


class CelebA(Dataset):

@@ -104,13 +105,10 @@ class CelebA(Dataset):
        "2": "test",
    }

-    def _filter_split(self, data: Tuple[str, str], *, split):
-        _, split_id = data
-        return self._SPLIT_ID_TO_NAME[split_id[0]] == split
+    def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
+        return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split

-    def _collate_anns(
-        self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
-    ) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
+    def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
        (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
        return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)

@@ -127,7 +125,7 @@ class CelebA(Dataset):
        image = decoder(buffer) if decoder else buffer

-        identity = torch.tensor(int(ann["identity"][0]))
+        identity = int(ann["identity"]["identity"])
        attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
        bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
        landmarks = {

@@ -153,24 +151,24 @@ class CelebA(Dataset):
    ) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

-        splits_dp = CelebACSVParser(splits_dp, has_header=False)
-        splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
+        splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
+        splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

        images_dp = ZipArchiveReader(images_dp)

-        anns_dp: IterDataPipe = Zipper(
+        anns_dp = Zipper(
            *[
-                CelebACSVParser(dp, has_header=has_header)
-                for dp, has_header in (
-                    (identities_dp, False),
-                    (attributes_dp, True),
-                    (bboxes_dp, True),
-                    (landmarks_dp, True),
+                CelebACSVParser(dp, fieldnames=fieldnames)
+                for dp, fieldnames in (
+                    (identities_dp, ("image_id", "identity")),
+                    (attributes_dp, None),
+                    (bboxes_dp, None),
+                    (landmarks_dp, None),
                )
            ]
        )
-        anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
+        anns_dp = Mapper(anns_dp, self._collate_anns)

        dp = KeyZipper(
            splits_dp,
...
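The parser refactor replaces the per-instance `_fmtparams` dict with a csv dialect registered once at module level and referenced by name. A standalone sketch of the same pattern (the sample rows are made up):

import csv
import io

# Register the whitespace-delimited dialect once, then refer to it by name everywhere.
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)

file = io.StringIO("000001.jpg 1\n000002.jpg 0\n")
for row in csv.DictReader(file, fieldnames=("image_id", "split_id"), dialect="celeba"):
    print(row["image_id"], row["split_id"])  # -> 000001.jpg 1, then 000002.jpg 0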
@@ -3,7 +3,7 @@ import functools
import io
import pathlib
import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
import numpy as np
import torch

@@ -56,7 +56,7 @@ class _CifarBase(Dataset):
    def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
        _, file = data
-        return pickle.load(file, encoding="latin1")
+        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))

    def _collate_and_decode(
        self,

@@ -86,9 +86,9 @@ class _CifarBase(Dataset):
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
-        dp: IterDataPipe = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
-        dp: IterDataPipe = Mapper(dp, self._unpickle)
+        dp = TarArchiveReader(dp)
+        dp = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp = Mapper(dp, self._unpickle)
        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

@@ -96,9 +96,9 @@ class _CifarBase(Dataset):
    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
-        dp: IterDataPipe = Mapper(dp, self._unpickle)
-        return next(iter(dp))[self._CATEGORIES_KEY]
+        dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
+        dp = Mapper(dp, self._unpickle)
+        return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


class Cifar10(_CifarBase):

@@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
    _META_FILE_NAME = "meta"
    _CATEGORIES_KEY = "fine_label_names"

-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
+    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
        path = pathlib.Path(data[0])
-        return path.name == config.split
+        return path.name == cast(str, config.split)

    @property
    def info(self) -> DatasetInfo:
...
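The added casts are what `warn_return_any` asks for: `pickle.load` and other untyped values are typed as Any, so handing them back from a function with a concrete return type is now reported unless the value is narrowed with `typing.cast`. A minimal illustration (hypothetical helper, not from the diff):

import pickle
from typing import Any, Dict, cast

def load_meta(path: str) -> Dict[str, Any]:
    with open(path, "rb") as file:
        # pickle.load returns Any; without the cast, mypy with warn_return_any reports
        # 'Returning Any from function declared to return "Dict[str, Any]"'.
        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))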
import io
import pathlib
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler

@@ -44,11 +44,11 @@ class ImageNet(Dataset):
    @property
    def category_to_wnid(self) -> Dict[str, str]:
-        return self.info.extra.category_to_wnid
+        return cast(Dict[str, str], self.info.extra.category_to_wnid)

    @property
    def wnid_to_category(self) -> Dict[str, str]:
-        return self.info.extra.wnid_to_category
+        return cast(Dict[str, str], self.info.extra.wnid_to_category)

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        if config.split == "train":

@@ -152,7 +152,7 @@ class ImageNet(Dataset):
        "n03710721": "tank suit",
    }

-    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
+    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
        resources = self.resources(self.default_config)
        devkit_dp = resources[1].to_datapipe(root / self.name)
        devkit_dp = TarArchiveReader(devkit_dp)

@@ -160,12 +160,15 @@ class ImageNet(Dataset):
        meta = next(iter(devkit_dp))[1]
        synsets = read_mat(meta, squeeze_me=True)["synsets"]
-        categories_and_wnids = [
-            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-            for _, wnid, category, _, num_children, *_ in synsets
-            # if num_children > 0, we are looking at a superclass that has no direct instance
-            if num_children == 0
-        ]
+        categories_and_wnids = cast(
+            List[Tuple[str, ...]],
+            [
+                (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
+                for _, wnid, category, _, num_children, *_ in synsets
+                # if num_children > 0, we are looking at a superclass that has no direct instance
+                if num_children == 0
+            ],
+        )
        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
        return categories_and_wnids
@@ -38,7 +38,7 @@ __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
prod = functools.partial(functools.reduce, operator.mul)

-class MNISTFileReader(IterDataPipe):
+class MNISTFileReader(IterDataPipe[np.ndarray]):
    _DTYPE_MAP = {
        8: "u1",  # uint8
        9: "i1",  # int8

@@ -48,13 +48,15 @@ class MNISTFileReader(IterDataPipe):
        14: "f8",  # float64
    }

-    def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None:
+    def __init__(
+        self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
+    ) -> None:
        self.datapipe = datapipe
        self.start = start
        self.stop = stop

    @staticmethod
-    def _decode(bytes):
+    def _decode(bytes: bytes) -> int:
        return int(codecs.encode(bytes, "hex"), 16)

    def __iter__(self) -> Iterator[np.ndarray]:

@@ -107,7 +109,7 @@ class _MNISTBase(Dataset):
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ):
+    ) -> Dict[str, Any]:
        image_array, label_array = data
        image: Union[torch.Tensor, io.BytesIO]

@@ -138,14 +140,14 @@ class _MNISTBase(Dataset):
        labels_dp = Decompressor(labels_dp)
        labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)

-        dp: IterDataPipe = Zipper(images_dp, labels_dp)
+        dp = Zipper(images_dp, labels_dp)
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))


class MNIST(_MNISTBase):
    @property
-    def info(self):
+    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "mnist",
            type=DatasetType.RAW,

@@ -176,7 +178,7 @@ class MNIST(_MNISTBase):
class FashionMNIST(MNIST):
    @property
-    def info(self):
+    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "fashionmnist",
            type=DatasetType.RAW,

@@ -209,7 +211,7 @@ class FashionMNIST(MNIST):
class KMNIST(MNIST):
    @property
-    def info(self):
+    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "kmnist",
            type=DatasetType.RAW,

@@ -231,7 +233,7 @@ class KMNIST(MNIST):
class EMNIST(_MNISTBase):
    @property
-    def info(self):
+    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "emnist",
            type=DatasetType.RAW,

@@ -295,7 +297,7 @@ class EMNIST(_MNISTBase):
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ):
+    ) -> Dict[str, Any]:
        image_array, label_array = data
        # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
        # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,

@@ -321,7 +323,7 @@ class EMNIST(_MNISTBase):
        images_dp, labels_dp = Demultiplexer(
            archive_dp,
            2,
-            functools.partial(self._classify_archive, config=config),  # type:ignore[arg-type]
+            functools.partial(self._classify_archive, config=config),
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

@@ -330,7 +332,7 @@ class EMNIST(_MNISTBase):
class QMNIST(_MNISTBase):
    @property
-    def info(self):
+    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "qmnist",
            type=DatasetType.RAW,

@@ -381,7 +383,7 @@ class QMNIST(_MNISTBase):
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ):
+    ) -> Dict[str, Any]:
        image_array, label_array = data
        label_parts = label_array.tolist()
        sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
...
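MNISTFileReader now derives from IterDataPipe with an explicit element type, so mypy knows what downstream pipes receive from it. A stripped-down sketch of the pattern (a toy pipe, not the real reader):

from typing import Iterator

import numpy as np
from torch.utils.data import IterDataPipe


class ConstantImagePipe(IterDataPipe[np.ndarray]):
    """Toy datapipe whose elements are typed as np.ndarray rather than Any."""

    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples

    def __iter__(self) -> Iterator[np.ndarray]:
        for _ in range(self.num_samples):
            yield np.zeros((28, 28), dtype=np.uint8)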
import io
import pathlib
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import numpy as np
import torch

@@ -135,7 +135,7 @@ class SBD(Dataset):
        split_dp, images_dp, anns_dp = Demultiplexer(
            archive_dp,
            3,
-            self._classify_archive,  # type: ignore[arg-type]
+            self._classify_archive,
            buffer_size=INFINITE_BUFFER_SIZE,
            drop_none=True,
        )

@@ -159,15 +159,21 @@ class SBD(Dataset):
    def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
+        dp = Filter(dp, path_comparator("name", "category_names.m"))
        dp = LineReader(dp)
-        dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
+        dp = Mapper(dp, bytes.decode, input_col=1)
        lines = tuple(zip(*iter(dp)))[1]

        pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
-        categories_and_labels = [
-            pattern.match(line).groups()  # type: ignore[union-attr]
-            # the first and last line contain no information
-            for line in lines[1:-1]
-        ]
-        return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
+        categories_and_labels = cast(
+            List[Tuple[str, ...]],
+            [
+                pattern.match(line).groups()  # type: ignore[union-attr]
+                # the first and last line contain no information
+                for line in lines[1:-1]
+            ],
+        )
+        categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1]))
+        categories, _ = zip(*categories_and_labels)
+        return categories
@@ -92,7 +92,11 @@ class VOC(Dataset):
        return torch.tensor(bboxes)

    def _collate_and_decode_sample(
-        self, data, *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
+        self,
+        data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        split_and_image_data, ann_data = data
        _, image_data = split_and_image_data

@@ -104,7 +108,7 @@ class VOC(Dataset):
        if config.task == "detection":
            ann = self._decode_detection_ann(ann_buffer)
        else:  # config.task == "segmentation":
-            ann = decoder(ann_buffer) if decoder else ann_buffer
+            ann = decoder(ann_buffer) if decoder else ann_buffer  # type: ignore[assignment]

        return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)

@@ -120,15 +124,13 @@ class VOC(Dataset):
        split_dp, images_dp, anns_dp = Demultiplexer(
            archive_dp,
            3,
-            functools.partial(self._classify_archive, config=config),  # type: ignore[arg-type]
+            functools.partial(self._classify_archive, config=config),
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

-        split_dp: IterDataPipe = Filter(
-            split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task])
-        )
-        split_dp: IterDataPipe = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
+        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
        split_dp = LineReader(split_dp, decode=True)
        split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
...
@@ -25,7 +25,7 @@ def _collate_and_decode_data(
    *,
    root: pathlib.Path,
    categories: List[str],
-    decoder,
+    decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
    path, buffer = data
    data = decoder(buffer) if decoder else buffer
...
+# type: ignore
import argparse
import collections.abc
import contextlib
...
import io
+from typing import cast
import PIL.Image
import torch

@@ -12,4 +13,4 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
-    return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))
+    return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())))
+# type: ignore
import argparse
import csv
import sys

@@ -50,7 +52,7 @@ def parse_args(argv=None):
if __name__ == "__main__":
-    args = parse_args()
+    args = parse_args(["-f", "sbd"])
    try:
        main(*args.names, force=args.force)
...
@@ -3,16 +3,7 @@ import csv
import enum
import io
import pathlib
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Sequence,
-    Union,
-    Tuple,
-)
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
import torch
from torch.utils.data import IterDataPipe
...
...@@ -9,7 +9,6 @@ import os ...@@ -9,7 +9,6 @@ import os
import os.path import os.path
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Mapping
from typing import ( from typing import (
Collection, Collection,
Sequence, Sequence,
...@@ -23,13 +22,14 @@ from typing import ( ...@@ -23,13 +22,14 @@ from typing import (
Optional, Optional,
NoReturn, NoReturn,
Iterable, Iterable,
Mapping,
) )
from typing import cast
import numpy as np import numpy as np
import PIL.Image import PIL.Image
from torch.utils.data import IterDataPipe from torch.utils.data import IterDataPipe
__all__ = [ __all__ = [
"INFINITE_BUFFER_SIZE", "INFINITE_BUFFER_SIZE",
"BUILTIN_DIR", "BUILTIN_DIR",
...@@ -83,7 +83,7 @@ def add_suggestion( ...@@ -83,7 +83,7 @@ def add_suggestion(
return f"{msg.strip()} {hint}" return f"{msg.strip()} {hint}"
def make_repr(name: str, items: Iterable[Tuple[str, Any]]): def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str: def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items]) return sep.join([f"{key}={value}" for key, value in items])
...@@ -101,29 +101,29 @@ def make_repr(name: str, items: Iterable[Tuple[str, Any]]): ...@@ -101,29 +101,29 @@ def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
return f"{prefix}\n{body}\n{postfix}" return f"{prefix}\n{body}\n{postfix}"
class FrozenMapping(Mapping): class FrozenMapping(Mapping[K, D]):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
data = dict(*args, **kwargs) data = dict(*args, **kwargs)
self.__dict__["__data__"] = data self.__dict__["__data__"] = data
self.__dict__["__final_hash__"] = hash(tuple(data.items())) self.__dict__["__final_hash__"] = hash(tuple(data.items()))
def __getitem__(self, name: str) -> Any: def __getitem__(self, item: K) -> D:
return self.__dict__["__data__"][name] return cast(Mapping[K, D], self.__dict__["__data__"])[item]
def __iter__(self): def __iter__(self) -> Iterator[K]:
return iter(self.__dict__["__data__"].keys()) return iter(self.__dict__["__data__"].keys())
def __len__(self): def __len__(self) -> int:
return len(self.__dict__["__data__"]) return len(self.__dict__["__data__"])
def __setitem__(self, key: Any, value: Any) -> NoReturn: def __setitem__(self, key: K, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable") raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __delitem__(self, key: Any) -> NoReturn: def __delitem__(self, key: K) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable") raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __hash__(self) -> int: def __hash__(self) -> int:
return self.__dict__["__final_hash__"] return cast(int, self.__dict__["__final_hash__"])
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, FrozenMapping): if not isinstance(other, FrozenMapping):
...@@ -131,7 +131,7 @@ class FrozenMapping(Mapping): ...@@ -131,7 +131,7 @@ class FrozenMapping(Mapping):
return hash(self) == hash(other) return hash(self) == hash(other)
def __repr__(self): def __repr__(self) -> str:
return repr(self.__dict__["__data__"]) return repr(self.__dict__["__data__"])
...@@ -205,7 +205,7 @@ class Enumerator(IterDataPipe[Tuple[int, D]]): ...@@ -205,7 +205,7 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
def getitem(*items: Any) -> Callable[[Any], Any]: def getitem(*items: Any) -> Callable[[Any], Any]:
def wrapper(obj: Any): def wrapper(obj: Any) -> Any:
for item in items: for item in items:
obj = obj[item] obj = obj[item]
return obj return obj
...@@ -218,7 +218,7 @@ def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[ ...@@ -218,7 +218,7 @@ def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[
name = getter name = getter
def getter(path: pathlib.Path) -> D: def getter(path: pathlib.Path) -> D:
return getattr(path, name) return cast(D, getattr(path, name))
def wrapper(data: Tuple[str, Any]) -> D: def wrapper(data: Tuple[str, Any]) -> D:
return getter(pathlib.Path(data[0])) # type: ignore[operator] return getter(pathlib.Path(data[0])) # type: ignore[operator]
......
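With FrozenMapping now parameterized over the module's K and D type variables, lookups are typed instead of falling back to Any. A rough usage sketch, assuming the class behaves as shown in the diff:

from torchvision.prototype.datasets.utils._internal import FrozenMapping

splits: FrozenMapping[str, str] = FrozenMapping({"0": "train", "1": "val", "2": "test"})

print(splits["1"])     # "val"; with the generic base, the value is typed as str
splits["3"] = "extra"  # RuntimeError: 'FrozenMapping' object is immutable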
@@ -8,7 +8,7 @@ from torch.utils.data.datapipes.iter import FileLoader, IterableWrapper
# FIXME
-def compute_sha256(_) -> str:
+def compute_sha256(path: pathlib.Path) -> str:
    return ""
...