Unverified commit 4ba91bff authored by Philip Meier, committed by GitHub

make mypy more strict for prototype datasets (#4513)

* make mypy more strict for prototype datasets

* fix code format

* apply strictness only to datasets

* fix more mypy issues

* cleanup

* fix mnist annotations

* refactor celeba

* warn on redundant casts

* remove redundant cast

* simplify annotation

* fix import
parent 9407b45a
......@@ -4,6 +4,23 @@ files = torchvision
show_error_codes = True
pretty = True
allow_redefinition = True
warn_redundant_casts = True
[mypy-torchvision.prototype.datasets.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.io._video_opt.*]
......
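Note: the new `[mypy-torchvision.prototype.datasets.*]` section above applies these checks only to the prototype datasets package. A minimal sketch (hypothetical function names) of what the two main new flags report:

from typing import Optional

# disallow_untyped_defs: every function in the package now needs full annotations.
def load(path):  # error: Function is missing a type annotation
    return open(path, "rb").read()

# no_implicit_optional: a default of None no longer implies Optional[...].
def read(buffer: bytes, limit: int = None) -> bytes:  # error: Incompatible default for argument "limit"
    return buffer[:limit]

# The fully annotated version passes under the stricter settings.
def read_ok(buffer: bytes, limit: Optional[int] = None) -> bytes:
    return buffer[:limit]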
import os
from typing import Any, Callable, cast, Optional, Tuple
from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
......@@ -63,7 +63,7 @@ class USPS(VisionDataset):
raw_data = [line.decode().split() for line in fp.readlines()]
tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
self.data = imgs
......
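The dropped `cast(np.ndarray, imgs)` above is the kind of no-op that the new `warn_redundant_casts = True` setting reports, since `np.asarray` already returns an `ndarray`. A self-contained sketch of the same warning on a simpler type (hypothetical function):

from typing import cast

def parse(raw: str) -> int:
    value = int(raw)
    # mypy with warn_redundant_casts: Redundant cast to "int"
    return cast(int, value)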
......@@ -82,7 +82,10 @@ class Caltech101(Dataset):
return category, id
def _collate_and_decode_sample(
self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self,
data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
key, image_data, ann_data = data
category, _ = key
......@@ -117,11 +120,11 @@ class Caltech101(Dataset):
images_dp, anns_dp = resource_dps
images_dp = TarArchiveReader(images_dp)
images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
anns_dp = TarArchiveReader(anns_dp)
anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
anns_dp = Filter(anns_dp, self._is_ann)
dp = KeyZipper(
images_dp,
......@@ -136,7 +139,7 @@ class Caltech101(Dataset):
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
......@@ -185,7 +188,7 @@ class Caltech256(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
dp = Filter(dp, self._is_not_rogue_file)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
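The removed `dp: IterDataPipe` annotations in the Caltech hunks rely on `allow_redefinition = True`: a variable may now be re-bound to a different type within the same block, so the intermediate annotations are no longer needed to satisfy mypy. A rough illustration with plain types (hypothetical function):

from typing import List

def build(paths: List[str]) -> List[int]:
    data = paths                   # inferred as List[str]
    data = [len(p) for p in data]  # re-bound to List[int]; an error without allow_redefinition
    return data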
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
import torch
from torchdata.datapipes.iter import (
......@@ -23,37 +23,38 @@ from torchvision.prototype.datasets.utils import (
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
class CelebACSVParser(IterDataPipe):
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe,
datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
*,
has_header,
):
fieldnames: Optional[Sequence[str]] = None,
) -> None:
self.datapipe = datapipe
self.has_header = has_header
self._fmtparams = dict(delimiter=" ", skipinitialspace=True)
self.fieldnames = fieldnames
def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)
if self.has_header:
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)
# Empty field names are filtered out, because some files have an extr white space after the header
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")
for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
yield line.pop("image_id"), line
else:
for line in csv.reader(file, **self._fmtparams):
yield line[0], line[1:]
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line
class CelebA(Dataset):
......@@ -104,13 +105,10 @@ class CelebA(Dataset):
"2": "test",
}
def _filter_split(self, data: Tuple[str, str], *, split):
_, split_id = data
return self._SPLIT_ID_TO_NAME[split_id[0]] == split
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
def _collate_anns(
self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
(image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
......@@ -127,7 +125,7 @@ class CelebA(Dataset):
image = decoder(buffer) if decoder else buffer
identity = torch.tensor(int(ann["identity"][0]))
identity = int(ann["identity"]["identity"])
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
......@@ -153,24 +151,24 @@ class CelebA(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, has_header=False)
splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
images_dp = ZipArchiveReader(images_dp)
anns_dp: IterDataPipe = Zipper(
anns_dp = Zipper(
*[
CelebACSVParser(dp, has_header=has_header)
for dp, has_header in (
(identities_dp, False),
(attributes_dp, True),
(bboxes_dp, True),
(landmarks_dp, True),
CelebACSVParser(dp, fieldnames=fieldnames)
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bboxes_dp, None),
(landmarks_dp, None),
)
]
)
anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
anns_dp = Mapper(anns_dp, self._collate_anns)
dp = KeyZipper(
splits_dp,
......
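The CelebA refactor swaps the `has_header` flag for an optional `fieldnames` sequence and moves the format parameters into a registered `csv` dialect, so every parsed row stays a `Dict[str, str]` whether or not the file ships its own header. A small stand-alone sketch of the parsing pattern, with made-up file contents:

import csv
import io

csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)

# Hypothetical stand-in for one of the space-separated annotation files.
raw = io.StringIO("000001.jpg 95  71 226 313\n000002.jpg 72  94 221 306\n")

fieldnames = ("image_id", "x_1", "y_1", "width", "height")
for row in csv.DictReader(raw, fieldnames=fieldnames, dialect="celeba"):
    image_id = row.pop("image_id")
    print(image_id, row)  # e.g. 000001.jpg {'x_1': '95', 'y_1': '71', ...}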
......@@ -3,7 +3,7 @@ import functools
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
import numpy as np
import torch
......@@ -56,7 +56,7 @@ class _CifarBase(Dataset):
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return pickle.load(file, encoding="latin1")
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
def _collate_and_decode(
self,
......@@ -86,9 +86,9 @@ class _CifarBase(Dataset):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp: IterDataPipe = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
dp: IterDataPipe = Mapper(dp, self._unpickle)
dp = TarArchiveReader(dp)
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
......@@ -96,9 +96,9 @@ class _CifarBase(Dataset):
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp: IterDataPipe = Mapper(dp, self._unpickle)
return next(iter(dp))[self._CATEGORIES_KEY]
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
class Cifar10(_CifarBase):
......@@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"
def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name == config.split
return path.name == cast(str, config.split)
@property
def info(self) -> DatasetInfo:
......
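The added casts around `pickle.load` follow from `warn_return_any = True`: `pickle.load` is typed as returning `Any`, and returning that directly from a function annotated `-> Dict[str, Any]` is now reported. A minimal sketch of the pattern (hypothetical function):

import pickle
from typing import Any, Dict, cast

def unpickle(buffer: bytes) -> Dict[str, Any]:
    # pickle.loads returns Any; casting it to the declared return type keeps
    # warn_return_any from flagging "Returning Any from function declared to return ..."
    return cast(Dict[str, Any], pickle.loads(buffer))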
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
......@@ -44,11 +44,11 @@ class ImageNet(Dataset):
@property
def category_to_wnid(self) -> Dict[str, str]:
return self.info.extra.category_to_wnid
return cast(Dict[str, str], self.info.extra.category_to_wnid)
@property
def wnid_to_category(self) -> Dict[str, str]:
return self.info.extra.wnid_to_category
return cast(Dict[str, str], self.info.extra.wnid_to_category)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train":
......@@ -152,7 +152,7 @@ class ImageNet(Dataset):
"n03710721": "tank suit",
}
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
resources = self.resources(self.default_config)
devkit_dp = resources[1].to_datapipe(root / self.name)
devkit_dp = TarArchiveReader(devkit_dp)
......@@ -160,12 +160,15 @@ class ImageNet(Dataset):
meta = next(iter(devkit_dp))[1]
synsets = read_mat(meta, squeeze_me=True)["synsets"]
categories_and_wnids = [
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
]
categories_and_wnids = cast(
List[Tuple[str, ...]],
[
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
],
)
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids
......@@ -38,7 +38,7 @@ __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
prod = functools.partial(functools.reduce, operator.mul)
class MNISTFileReader(IterDataPipe):
class MNISTFileReader(IterDataPipe[np.ndarray]):
_DTYPE_MAP = {
8: "u1", # uint8
9: "i1", # int8
......@@ -48,13 +48,15 @@ class MNISTFileReader(IterDataPipe):
14: "f8", # float64
}
def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None:
def __init__(
self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
) -> None:
self.datapipe = datapipe
self.start = start
self.stop = stop
@staticmethod
def _decode(bytes):
def _decode(bytes: bytes) -> int:
return int(codecs.encode(bytes, "hex"), 16)
def __iter__(self) -> Iterator[np.ndarray]:
......@@ -107,7 +109,7 @@ class _MNISTBase(Dataset):
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data
image: Union[torch.Tensor, io.BytesIO]
......@@ -138,14 +140,14 @@ class _MNISTBase(Dataset):
labels_dp = Decompressor(labels_dp)
labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
dp: IterDataPipe = Zipper(images_dp, labels_dp)
dp = Zipper(images_dp, labels_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
class MNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"mnist",
type=DatasetType.RAW,
......@@ -176,7 +178,7 @@ class MNIST(_MNISTBase):
class FashionMNIST(MNIST):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"fashionmnist",
type=DatasetType.RAW,
......@@ -209,7 +211,7 @@ class FashionMNIST(MNIST):
class KMNIST(MNIST):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"kmnist",
type=DatasetType.RAW,
......@@ -231,7 +233,7 @@ class KMNIST(MNIST):
class EMNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"emnist",
type=DatasetType.RAW,
......@@ -295,7 +297,7 @@ class EMNIST(_MNISTBase):
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
# That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
......@@ -321,7 +323,7 @@ class EMNIST(_MNISTBase):
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
functools.partial(self._classify_archive, config=config), # type:ignore[arg-type]
functools.partial(self._classify_archive, config=config),
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
......@@ -330,7 +332,7 @@ class EMNIST(_MNISTBase):
class QMNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"qmnist",
type=DatasetType.RAW,
......@@ -381,7 +383,7 @@ class QMNIST(_MNISTBase):
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data
label_parts = label_array.tolist()
sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
......
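`MNISTFileReader` (like `CelebACSVParser` above) now subclasses the generic `IterDataPipe[np.ndarray]`, so mypy and downstream pipes know what `__iter__` yields. A reduced sketch of that pattern with a hypothetical pipe:

from typing import Iterator, Sequence, Tuple

from torch.utils.data import IterDataPipe


class PairReader(IterDataPipe[Tuple[str, int]]):
    """Yields (item, length) pairs; the element type lives in the generic parameter."""

    def __init__(self, items: Sequence[str]) -> None:
        self.items = items

    def __iter__(self) -> Iterator[Tuple[str, int]]:
        for item in self.items:
            yield item, len(item)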
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
......@@ -135,7 +135,7 @@ class SBD(Dataset):
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive, # type: ignore[arg-type]
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
......@@ -159,15 +159,21 @@ class SBD(Dataset):
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
dp = Filter(dp, path_comparator("name", "category_names.m"))
dp = LineReader(dp)
dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
dp = Mapper(dp, bytes.decode, input_col=1)
lines = tuple(zip(*iter(dp)))[1]
pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
categories_and_labels = [
pattern.match(line).groups() # type: ignore[union-attr]
# the first and last line contain no information
for line in lines[1:-1]
]
return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
categories_and_labels = cast(
List[Tuple[str, ...]],
[
pattern.match(line).groups() # type: ignore[union-attr]
# the first and last line contain no information
for line in lines[1:-1]
],
)
categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1]))
categories, _ = zip(*categories_and_labels)
return categories
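The remaining `type: ignore[union-attr]` is needed because `re.Pattern.match` returns `Optional[Match[str]]`, so calling `.groups()` on the result is not safe as far as mypy is concerned. A sketch of the guard that would avoid the ignore, using a hypothetical helper around the same regex:

import re
from typing import Optional, Tuple

pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")

def parse_line(line: str) -> Optional[Tuple[str, ...]]:
    match = pattern.match(line)
    # pattern.match returns Optional[Match[str]]; checking for None here avoids
    # the `type: ignore[union-attr]` used in _generate_categories above.
    return match.groups() if match else None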
......@@ -92,7 +92,11 @@ class VOC(Dataset):
return torch.tensor(bboxes)
def _collate_and_decode_sample(
self, data, *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
......@@ -104,7 +108,7 @@ class VOC(Dataset):
if config.task == "detection":
ann = self._decode_detection_ann(ann_buffer)
else: # config.task == "segmentation":
ann = decoder(ann_buffer) if decoder else ann_buffer
ann = decoder(ann_buffer) if decoder else ann_buffer # type: ignore[assignment]
return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)
......@@ -120,15 +124,13 @@ class VOC(Dataset):
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
functools.partial(self._classify_archive, config=config), # type: ignore[arg-type]
functools.partial(self._classify_archive, config=config),
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_dp: IterDataPipe = Filter(
split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task])
)
split_dp: IterDataPipe = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
......
......@@ -25,7 +25,7 @@ def _collate_and_decode_data(
*,
root: pathlib.Path,
categories: List[str],
decoder,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
path, buffer = data
data = decoder(buffer) if decoder else buffer
......
# type: ignore
import argparse
import collections.abc
import contextlib
......
import io
from typing import cast
import PIL.Image
import torch
......@@ -12,4 +13,4 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))
return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())))
# type: ignore
import argparse
import csv
import sys
......@@ -50,7 +52,7 @@ def parse_args(argv=None):
if __name__ == "__main__":
args = parse_args()
args = parse_args(["-f", "sbd"])
try:
main(*args.names, force=args.force)
......
......@@ -3,16 +3,7 @@ import csv
import enum
import io
import pathlib
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Union,
Tuple,
)
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
import torch
from torch.utils.data import IterDataPipe
......
......@@ -9,7 +9,6 @@ import os
import os.path
import pathlib
import textwrap
from collections.abc import Mapping
from typing import (
Collection,
Sequence,
......@@ -23,13 +22,14 @@ from typing import (
Optional,
NoReturn,
Iterable,
Mapping,
)
from typing import cast
import numpy as np
import PIL.Image
from torch.utils.data import IterDataPipe
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
......@@ -83,7 +83,7 @@ def add_suggestion(
return f"{msg.strip()} {hint}"
def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items])
......@@ -101,29 +101,29 @@ def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
return f"{prefix}\n{body}\n{postfix}"
class FrozenMapping(Mapping):
def __init__(self, *args, **kwargs):
class FrozenMapping(Mapping[K, D]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
data = dict(*args, **kwargs)
self.__dict__["__data__"] = data
self.__dict__["__final_hash__"] = hash(tuple(data.items()))
def __getitem__(self, name: str) -> Any:
return self.__dict__["__data__"][name]
def __getitem__(self, item: K) -> D:
return cast(Mapping[K, D], self.__dict__["__data__"])[item]
def __iter__(self):
def __iter__(self) -> Iterator[K]:
return iter(self.__dict__["__data__"].keys())
def __len__(self):
def __len__(self) -> int:
return len(self.__dict__["__data__"])
def __setitem__(self, key: Any, value: Any) -> NoReturn:
def __setitem__(self, key: K, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __delitem__(self, key: Any) -> NoReturn:
def __delitem__(self, key: K) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
def __hash__(self) -> int:
return self.__dict__["__final_hash__"]
return cast(int, self.__dict__["__final_hash__"])
def __eq__(self, other: Any) -> bool:
if not isinstance(other, FrozenMapping):
......@@ -131,7 +131,7 @@ class FrozenMapping(Mapping):
return hash(self) == hash(other)
def __repr__(self):
def __repr__(self) -> str:
return repr(self.__dict__["__data__"])
......@@ -205,7 +205,7 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
def getitem(*items: Any) -> Callable[[Any], Any]:
def wrapper(obj: Any):
def wrapper(obj: Any) -> Any:
for item in items:
obj = obj[item]
return obj
......@@ -218,7 +218,7 @@ def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[
name = getter
def getter(path: pathlib.Path) -> D:
return getattr(path, name)
return cast(D, getattr(path, name))
def wrapper(data: Tuple[str, Any]) -> D:
return getter(pathlib.Path(data[0])) # type: ignore[operator]
......
......@@ -8,7 +8,7 @@ from torch.utils.data.datapipes.iter import FileLoader, IterableWrapper
# FIXME
def compute_sha256(_) -> str:
def compute_sha256(path: pathlib.Path) -> str:
return ""
......