"vscode:/vscode.git/clone" did not exist on "1e853e240ed6c33e5c43221dbbe7ef3dc577cf37"
Unverified Commit 49ec677c authored by Philip Meier, committed by GitHub

add tests for prototype builtin datasets (#4682)

* add tests for builtin prototype datasets

* fix caltech101

* fix emnist

* fix mnist and variants

* add iopath as test requirement

* fix MNIST warning

* fix qmnist data generation

* fix cifar data generation

* add tests for imagenet

* cleanup
parent 1cbd9cde
@@ -273,8 +273,8 @@ jobs:
           name: Install torchvision
           command: pip install --user --progress-bar off --no-build-isolation .
       - run:
-          name: Install test utilities
-          command: pip install --user --progress-bar=off pytest pytest-mock
+          name: Install test requirements
+          command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath
       - run:
           name: Run tests
           command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py
......
@@ -273,8 +273,8 @@ jobs:
           name: Install torchvision
           command: pip install --user --progress-bar off --no-build-isolation .
       - run:
-          name: Install test utilities
-          command: pip install --user --progress-bar=off pytest pytest-mock
+          name: Install test requirements
+          command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath
       - run:
           name: Run tests
           command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py
......
import functools
import gzip
import lzma
import pathlib
import pickle
import tempfile
from collections import defaultdict
from typing import Any, Dict, Tuple

import numpy as np
import pytest
import torch
from datasets_utils import create_image_folder, make_tar, make_zip
from torch.testing import make_tensor as _make_tensor
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import find
from torchvision.prototype.datasets.utils._internal import add_suggestion

make_tensor = functools.partial(_make_tensor, device="cpu")

__all__ = ["load"]

DEFAULT_TEST_DECODER = object()


class DatasetMocks:
    def __init__(self):
        self._mock_data_fns = {}
        self._tmp_home = pathlib.Path(tempfile.mkdtemp())
        self._cache = {}

    def register_mock_data_fn(self, mock_data_fn):
        name = mock_data_fn.__name__
        if name not in datasets.list():
            raise pytest.UsageError(
                add_suggestion(
                    f"The name of the mock data function '{name}' has no corresponding dataset.",
                    word=name,
                    possibilities=datasets.list(),
                    close_match_hint=lambda close_match: f"Did you mean to name it '{close_match}'?",
                    alternative_hint=lambda _: "",
                )
            )
        self._mock_data_fns[name] = mock_data_fn
        return mock_data_fn

    def _parse_mock_info(self, mock_info, *, name):
        if mock_info is None:
            raise pytest.UsageError(
                f"The mock data function for dataset '{name}' returned nothing. It needs to at least return an integer "
                f"indicating the number of samples for the current `config`."
            )
        elif isinstance(mock_info, int):
            mock_info = dict(num_samples=mock_info)
        elif not isinstance(mock_info, dict):
            raise pytest.UsageError(
                f"The mock data function for dataset '{name}' returned a {type(mock_info)}. The returned object should "
                f"be a dictionary containing at least the number of samples for the current `config` for the key "
                f"`'num_samples'`. If no additional information is required for specific tests, the number of samples "
                f"can also be returned as an integer."
            )
        elif "num_samples" not in mock_info:
            raise pytest.UsageError(
                f"The dictionary returned by the mock data function for dataset '{name}' must contain a `'num_samples'` "
                f"entry indicating the number of samples for the current `config`."
            )
        return mock_info

    def _get(self, dataset, config):
        name = dataset.info.name
        resources_and_mock_info = self._cache.get((name, config))
        if resources_and_mock_info:
            return resources_and_mock_info

        try:
            fakedata_fn = self._mock_data_fns[name]
        except KeyError:
            raise pytest.UsageError(
                f"No mock data available for dataset '{name}'. "
                f"Did you add a new dataset, but forget to provide mock data for it? "
                f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
            )

        root = self._tmp_home / name
        root.mkdir(exist_ok=True)
        mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)

        mock_resources = []
        for resource in dataset.resources(config):
            path = root / resource.file_name
            if not (path.exists() and path.is_file()):
                raise pytest.UsageError(
                    f"Dataset '{name}' requires the file {path.name} for {config}, but this file does not exist."
                )
            mock_resources.append(datasets.utils.LocalResource(path))

        self._cache[(name, config)] = mock_resources, mock_info
        return mock_resources, mock_info

    def _decoder(self, dataset_type):
        if dataset_type == datasets.utils.DatasetType.RAW:
            return datasets.decoder.raw
        else:
            return lambda file: file.close()

    def load(
        self, name: str, decoder=DEFAULT_TEST_DECODER, split="train", **options: Any
    ) -> Tuple[IterDataPipe, Dict[str, Any]]:
        dataset = find(name)
        config = dataset.info.make_config(split=split, **options)
        resources, mock_info = self._get(dataset, config)
        datapipe = dataset._make_datapipe(
            [resource.to_datapipe() for resource in resources],
            config=config,
            decoder=self._decoder(dataset.info.type) if decoder is DEFAULT_TEST_DECODER else decoder,
        )
        return datapipe, mock_info


dataset_mocks = DatasetMocks()

load = dataset_mocks.load
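
To see how these pieces fit together, here is a minimal sketch (not part of the test file; the dataset name "foo" is made up): a mock data function named exactly like a builtin dataset is registered on `dataset_mocks`, writes fake files for the requested `config` into `root`, and returns the number of generated samples; tests then obtain a wired-up datapipe plus that info via `load`.

@dataset_mocks.register_mock_data_fn
def foo(info, root, config):  # the name must match an entry of datasets.list()
    # write the files that dataset.resources(config) expects into `root` here ...
    return 3  # number of fake samples, or a dict with at least a "num_samples" entry

datapipe, mock_info = load("foo", split="train")
assert mock_info["num_samples"] == 3
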
class MNISTFakedata:
    _DTYPES_ID = {
        torch.uint8: 8,
        torch.int8: 9,
        torch.int16: 11,
        torch.int32: 12,
        torch.float32: 13,
        torch.float64: 14,
    }

    @classmethod
    def _magic(cls, dtype, ndim):
        return cls._DTYPES_ID[dtype] * 256 + ndim + 1

    @staticmethod
    def _encode(t):
        return torch.tensor(t, dtype=torch.int32).numpy().tobytes()[::-1]

    @staticmethod
    def _big_endian_dtype(dtype):
        np_dtype = getattr(np, str(dtype).replace("torch.", ""))().dtype
        return np.dtype(f">{np_dtype.kind}{np_dtype.itemsize}")

    @classmethod
    def _create_binary_file(cls, root, filename, *, num_samples, shape, dtype, compressor, low=0, high):
        with compressor(root / filename, "wb") as fh:
            for meta in (cls._magic(dtype, len(shape)), num_samples, *shape):
                fh.write(cls._encode(meta))

            data = make_tensor((num_samples, *shape), dtype=dtype, low=low, high=high)
            fh.write(data.numpy().astype(cls._big_endian_dtype(dtype)).tobytes())

    @classmethod
    def generate(
        cls,
        root,
        *,
        num_categories,
        num_samples=None,
        images_file,
        labels_file,
        image_size=(28, 28),
        image_dtype=torch.uint8,
        label_size=(),
        label_dtype=torch.uint8,
        compressor=None,
    ):
        if num_samples is None:
            num_samples = num_categories
        if compressor is None:
            compressor = gzip.open

        cls._create_binary_file(
            root,
            images_file,
            num_samples=num_samples,
            shape=image_size,
            dtype=image_dtype,
            compressor=compressor,
            high=float("inf"),
        )
        cls._create_binary_file(
            root,
            labels_file,
            num_samples=num_samples,
            shape=label_size,
            dtype=label_dtype,
            compressor=compressor,
            high=num_categories,
        )

        return num_samples
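
The files written by `_create_binary_file` follow a big-endian IDX-style layout: a magic number encoding dtype and per-sample dimensionality, the number of samples, the per-sample shape, and then the raw data. A minimal read-back sketch that inverts `_magic`/`_encode` (illustration only, not part of the commit; the file name is assumed):

import gzip
import struct

def read_fake_idx_header(path):
    with gzip.open(path, "rb") as fh:
        magic = struct.unpack(">i", fh.read(4))[0]
        dtype_id, ndim = magic // 256, magic % 256 - 1  # inverse of MNISTFakedata._magic
        num_samples = struct.unpack(">i", fh.read(4))[0]
        shape = [struct.unpack(">i", fh.read(4))[0] for _ in range(ndim)]
    return dtype_id, num_samples, shape

# For a file generated with the defaults above this returns (8, num_samples, [28, 28]).
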
@dataset_mocks.register_mock_data_fn
def mnist(info, root, config):
    train = config.split == "train"
    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
    return MNISTFakedata.generate(
        root,
        num_categories=len(info.categories),
        images_file=images_file,
        labels_file=labels_file,
    )


@dataset_mocks.register_mock_data_fn
def fashionmnist(info, root, config):
    train = config.split == "train"
    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
    return MNISTFakedata.generate(
        root,
        num_categories=len(info.categories),
        images_file=images_file,
        labels_file=labels_file,
    )


@dataset_mocks.register_mock_data_fn
def kmnist(info, root, config):
    train = config.split == "train"
    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
    return MNISTFakedata.generate(
        root,
        num_categories=len(info.categories),
        images_file=images_file,
        labels_file=labels_file,
    )


@dataset_mocks.register_mock_data_fn
def emnist(info, root, config):
    # The image sets that merge some lowercase letters into their respective uppercase variant still use dense
    # labels in the data files. Thus, num_categories != len(categories) there.
    num_categories = defaultdict(
        lambda: len(info.categories), **{image_set: 47 for image_set in ("Balanced", "By_Merge")}
    )
    num_samples = {}
    file_names = set()
    for _config in info._configs:
        prefix = f"emnist-{_config.image_set.replace('_', '').lower()}-{_config.split}"
        images_file = f"{prefix}-images-idx3-ubyte.gz"
        labels_file = f"{prefix}-labels-idx1-ubyte.gz"
        file_names.update({images_file, labels_file})
        num_samples[_config.image_set] = MNISTFakedata.generate(
            root,
            num_categories=num_categories[_config.image_set],
            images_file=images_file,
            labels_file=labels_file,
        )

    make_zip(root, "emnist-gzip.zip", *file_names)

    return num_samples[config.image_set]


@dataset_mocks.register_mock_data_fn
def qmnist(info, root, config):
    num_categories = len(info.categories)
    if config.split == "train":
        num_samples = num_samples_gen = num_categories + 2
        prefix = "qmnist-train"
        suffix = ".gz"
        compressor = gzip.open
    elif config.split.startswith("test"):
        # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
        # more than 10000 images for the dataset not to be empty.
        num_samples = num_samples_gen = 10001
        if config.split == "test10k":
            num_samples = min(num_samples, 10000)
        if config.split == "test50k":
            num_samples -= 10000
        prefix = "qmnist-test"
        suffix = ".gz"
        compressor = gzip.open
    else:  # config.split == "nist"
        num_samples = num_samples_gen = num_categories + 3
        prefix = "xnist"
        suffix = ".xz"
        compressor = lzma.open

    MNISTFakedata.generate(
        root,
        num_categories=num_categories,
        num_samples=num_samples_gen,
        images_file=f"{prefix}-images-idx3-ubyte{suffix}",
        labels_file=f"{prefix}-labels-idx2-int{suffix}",
        label_size=(8,),
        label_dtype=torch.int32,
        compressor=compressor,
    )
    return num_samples
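
To make the branch above concrete, a tiny worked sketch (not part of the commit) of the sample counts it reports, given that 10001 fake test images are generated:

num_generated = 10001
reported = {
    "test": num_generated,                 # 10001
    "test10k": min(num_generated, 10000),  # 10000
    "test50k": num_generated - 10000,      # 1: one image past index 10000 keeps the split non-empty
}
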
class CIFARFakedata:
    NUM_PIXELS = 32 * 32 * 3

    @classmethod
    def _create_batch_file(cls, root, name, *, num_categories, labels_key, num_samples=1):
        content = {
            "data": make_tensor((num_samples, cls.NUM_PIXELS), dtype=torch.uint8).numpy(),
            labels_key: torch.randint(0, num_categories, size=(num_samples,)).tolist(),
        }
        with open(pathlib.Path(root) / name, "wb") as fh:
            pickle.dump(content, fh)

    @classmethod
    def generate(
        cls,
        root,
        name,
        *,
        folder,
        train_files,
        test_files,
        num_categories,
        labels_key,
    ):
        folder = root / folder
        folder.mkdir()
        files = (*train_files, *test_files)
        for file in files:
            cls._create_batch_file(
                folder,
                file,
                num_categories=num_categories,
                labels_key=labels_key,
            )

        make_tar(root, name, folder, compression="gz")
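
What a generated batch file contains can be checked by reading it back out of the archive, e.g. for the CIFAR-10 mock defined next. A minimal sketch (not part of the commit), assuming `root` is the directory the mock wrote into:

import pickle
import tarfile

with tarfile.open(root / "cifar-10-python.tar.gz", "r:gz") as archive:
    with archive.extractfile("cifar-10-batches-py/data_batch_1") as fh:
        batch = pickle.load(fh)

# batch["data"] has shape (1, 3072), i.e. one flattened fake 32x32x3 image per batch file,
# and batch["labels"] is a list holding a single random label in [0, 10).
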
@dataset_mocks.register_mock_data_fn
def cifar10(info, root, config):
    train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
    test_files = ["test_batch"]

    CIFARFakedata.generate(
        root=root,
        name="cifar-10-python.tar.gz",
        folder=pathlib.Path("cifar-10-batches-py"),
        train_files=train_files,
        test_files=test_files,
        num_categories=10,
        labels_key="labels",
    )

    return len(train_files if config.split == "train" else test_files)


@dataset_mocks.register_mock_data_fn
def cifar100(info, root, config):
    train_files = ["train"]
    test_files = ["test"]

    CIFARFakedata.generate(
        root=root,
        name="cifar-100-python.tar.gz",
        folder=pathlib.Path("cifar-100-python"),
        train_files=train_files,
        test_files=test_files,
        num_categories=100,
        labels_key="fine_labels",
    )

    return len(train_files if config.split == "train" else test_files)


@dataset_mocks.register_mock_data_fn
def caltech101(info, root, config):
    def create_ann_file(root, name):
        import scipy.io

        box_coord = make_tensor((1, 4), dtype=torch.int32, low=0).numpy().astype(np.uint16)
        obj_contour = make_tensor((2, int(torch.randint(3, 6, size=()))), dtype=torch.float64, low=0).numpy()

        scipy.io.savemat(str(pathlib.Path(root) / name), dict(box_coord=box_coord, obj_contour=obj_contour))

    def create_ann_folder(root, name, file_name_fn, num_examples):
        root = pathlib.Path(root) / name
        root.mkdir(parents=True)

        for idx in range(num_examples):
            create_ann_file(root, file_name_fn(idx))

    images_root = root / "101_ObjectCategories"
    anns_root = root / "Annotations"

    ann_category_map = {
        "Faces_2": "Faces",
        "Faces_3": "Faces_easy",
        "Motorbikes_16": "Motorbikes",
        "Airplanes_Side_2": "airplanes",
    }

    num_images_per_category = 2
    for category in info.categories:
        create_image_folder(
            root=images_root,
            name=category,
            file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg",
            num_examples=num_images_per_category,
        )
        create_ann_folder(
            root=anns_root,
            name=ann_category_map.get(category, category),
            file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat",
            num_examples=num_images_per_category,
        )
(images_root / "BACKGROUND_Goodle").mkdir()
make_tar(root, f"{images_root.name}.tar.gz", images_root, compression="gz")
make_tar(root, f"{anns_root.name}.tar", anns_root)
return num_images_per_category * len(info.categories)
@dataset_mocks.register_mock_data_fn
def caltech256(info, root, config):
dir = root / "256_ObjectCategories"
num_images_per_category = 2
for idx, category in enumerate(info.categories, 1):
files = create_image_folder(
dir,
name=f"{idx:03d}.{category}",
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
num_examples=num_images_per_category,
)
if category == "spider":
open(files[0].parent / "RENAME2", "w").close()
make_tar(root, f"{dir.name}.tar", dir)
return num_images_per_category * len(info.categories)
@dataset_mocks.register_mock_data_fn
def imagenet(info, root, config):
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train":
images_root = root / "ILSVRC2012_img_train"
num_samples = len(wnids)
for wnid in wnids:
files = create_image_folder(
root=images_root,
name=wnid,
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
num_examples=1,
)
make_tar(images_root, f"{wnid}.tar", files[0].parent)
else:
num_samples = 3
files = create_image_folder(
root=root,
name="ILSVRC2012_img_val",
file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
images_root = files[0].parent
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{images_root.name}.tar", images_root)
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
return num_samples
@@ -6,9 +6,12 @@ import itertools
 import os
 import pathlib
 import random
+import shutil
 import string
+import tarfile
 import unittest
 import unittest.mock
+import zipfile
 from collections import defaultdict
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
@@ -33,6 +36,8 @@ __all__ = [
     "create_image_folder",
     "create_video_file",
     "create_video_folder",
+    "make_tar",
+    "make_zip",
     "create_random_string",
 ]
@@ -835,6 +840,69 @@ def create_video_folder(
     ]


+def _split_files_or_dirs(root, *files_or_dirs):
+    files = set()
+    dirs = set()
+    for file_or_dir in files_or_dirs:
+        path = pathlib.Path(file_or_dir)
+        if not path.is_absolute():
+            path = root / path
+        if path.is_file():
+            files.add(path)
+        else:
+            dirs.add(path)
+            for sub_file_or_dir in path.glob("**/*"):
+                if sub_file_or_dir.is_file():
+                    files.add(sub_file_or_dir)
+                else:
+                    dirs.add(sub_file_or_dir)
+
+    if root in dirs:
+        dirs.remove(root)
+
+    return files, dirs
+
+
+def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
+    archive = pathlib.Path(root) / name
+    files, dirs = _split_files_or_dirs(root, *files_or_dirs)
+
+    with opener(archive) as fh:
+        for file in files:
+            adder(fh, file, file.relative_to(root))
+
+    if remove:
+        for file in files:
+            os.remove(file)
+        for dir in dirs:
+            shutil.rmtree(dir, ignore_errors=True)
+
+    return archive
+
+
+def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
+    # TODO: detect compression from name
+    return _make_archive(
+        root,
+        name,
+        *files_or_dirs,
+        opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
+        adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
+        remove=remove,
+    )
+
+
+def make_zip(root, name, *files_or_dirs, remove=True):
+    return _make_archive(
+        root,
+        name,
+        *files_or_dirs,
+        opener=lambda archive: zipfile.ZipFile(archive, "w"),
+        adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
+        remove=remove,
+    )
+
+
 def create_random_string(length: int, *digits: str) -> str:
     """Create a random string.
......
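
To illustrate how the new helpers are meant to be called, a minimal usage sketch (not part of the diff; paths are made up): `make_tar` and `make_zip` pack files or whole directories found under `root` into an archive at `root / name` and, with the default `remove=True`, delete the loose originals afterwards.

import pathlib
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
(root / "images").mkdir()
(root / "images" / "0.jpg").write_bytes(b"not really a jpeg")

# assumes make_tar from the patched datasets_utils module is in scope
archive = make_tar(root, "images.tar.gz", "images", compression="gz")
# archive == root / "images.tar.gz"; the loose images/ folder has been removed
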
import functools
import io

import builtin_dataset_mocks
import pytest
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import sequence_to_str

_loaders = []
_datasets = []

# TODO: this can be replaced by torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
TMP = [
    "mnist",
    "fashionmnist",
    "kmnist",
    "emnist",
    "qmnist",
    "cifar10",
    "cifar100",
    "caltech256",
    "caltech101",
    "imagenet",
]
for name in TMP:
    loader = functools.partial(builtin_dataset_mocks.load, name)
    _loaders.append(pytest.param(loader, id=name))

    info = datasets.info(name)
    _datasets.extend(
        [
            pytest.param(*loader(**config), id=f"{name}-{'-'.join([str(value) for value in config.values()])}")
            for config in info._configs
        ]
    )

loaders = pytest.mark.parametrize("loader", _loaders)
builtin_datasets = pytest.mark.parametrize(("dataset", "mock_info"), _datasets)


class TestCommon:
    @builtin_datasets
    def test_smoke(self, dataset, mock_info):
        if not isinstance(dataset, IterDataPipe):
            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

    @builtin_datasets
    def test_sample(self, dataset, mock_info):
        try:
            sample = next(iter(dataset))
        except Exception as error:
            raise AssertionError("Drawing a sample raised the error above.") from error

        if not isinstance(sample, dict):
            raise AssertionError(f"Samples should be dictionaries, but got {type(sample)} instead.")

        if not sample:
            raise AssertionError("Sample dictionary is empty.")

    @builtin_datasets
    def test_num_samples(self, dataset, mock_info):
        num_samples = 0
        for _ in dataset:
            num_samples += 1

        assert num_samples == mock_info["num_samples"]

    @builtin_datasets
    def test_decoding(self, dataset, mock_info):
        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
        if undecoded_features:
            raise AssertionError(
                f"The values of key(s) "
                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
            )


class TestQMNIST:
    @pytest.mark.parametrize(
        "dataset",
        [
            pytest.param(builtin_dataset_mocks.load("qmnist", split=split)[0], id=split)
            for split in ("train", "test", "test10k", "test50k", "nist")
        ],
    )
    def test_extra_label(self, dataset):
        sample = next(iter(dataset))
        for key, type in (
            ("nist_hsf_series", int),
            ("nist_writer_id", int),
            ("digit_index", int),
            ("nist_label", int),
            ("global_digit_index", int),
            ("duplicate", bool),
            ("unused", bool),
        ):
            assert key in sample and isinstance(sample[key], type)
@@ -81,11 +81,11 @@ class Caltech101(Dataset):
     def _collate_and_decode_sample(
         self,
-        data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        key, image_data, ann_data = data
+        key, (image_data, ann_data) = data
         category, _ = key
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
......
@@ -127,7 +127,7 @@ class ImageNet(Dataset):
         if config.split == "train":
             # the train archive is a tar of tars
             dp = TarArchiveReader(images_dp)
-            # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+            dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_train_data)
         else:
             devkit_dp = TarArchiveReader(devkit_dp)
......
@@ -56,8 +56,21 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
         self.stop = stop

     @staticmethod
-    def _decode(bytes: bytes) -> int:
-        return int(codecs.encode(bytes, "hex"), 16)
+    def _decode(input: bytes) -> int:
+        return int(codecs.encode(input, "hex"), 16)
+
+    @staticmethod
+    def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
+        # As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
+        # avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
+        # contents, which would otherwise fail.
+        chunk = bytearray(chunk)
+        if reverse_bytes:
+            chunk.reverse()
+            tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
+        else:
+            tensor = torch.frombuffer(chunk, dtype=dtype)
+        return tensor.reshape(shape)

     def __iter__(self) -> Iterator[torch.Tensor]:
         for _, file in self.datapipe:
@@ -71,21 +84,15 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
             num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
             # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
             # we need to reverse the bytes before we can read them with torch.frombuffer().
-            needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
+            reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1

             chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value

             start = self.start or 0
-            stop = self.stop or num_samples
+            stop = min(self.stop, num_samples) if self.stop else num_samples

             file.seek(start * chunk_size, 1)
             for _ in range(stop - start):
-                chunk = file.read(chunk_size)
-                if not needs_byte_reversal:
-                    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
-                chunk = bytearray(chunk)
-                chunk.reverse()
-                yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
+                yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)


 class _MNISTBase(Dataset):
@@ -308,7 +315,9 @@ class EMNIST(_MNISTBase):
         # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
         # in self.categories. Thus, we need to add 1 to the label to correct this.
         if config.image_set in ("Balanced", "By_Merge"):
-            data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)
+            image, label = data
+            label += self._LABEL_OFFSETS.get(int(label), 0)
+            data = (image, label)
         return super()._collate_and_decode(data, config=config, decoder=decoder)

     def _make_datapipe(
......
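
A side note on the `_to_tensor` helper added above: copying the chunk into a `bytearray` makes it writeable for `torch.frombuffer`, and for big-endian data on a little-endian machine reversing the whole buffer and then flipping the resulting tensor restores the original value order. A standalone illustration (not part of the diff):

import torch

chunk = (1).to_bytes(4, "big") + (2).to_bytes(4, "big")  # two big-endian int32 values: 1, 2
buf = bytearray(chunk)                                   # writeable copy, as in _to_tensor
buf.reverse()                                            # now little-endian bytes, but values in reverse order
assert torch.frombuffer(buf, dtype=torch.int32).flip(0).tolist() == [1, 2]
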
@@ -2,6 +2,7 @@ import abc
 import csv
 import enum
 import io
+import itertools
 import os
 import pathlib
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
@@ -68,18 +69,22 @@ class DatasetInfo:
                 f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}."
             )
         self._valid_options: Dict[str, Sequence] = valid_options
+        self._configs = tuple(
+            DatasetConfig(**dict(zip(valid_options.keys(), combination)))
+            for combination in itertools.product(*valid_options.values())
+        )
         self.extra = FrozenBunch(extra or dict())

+    @property
+    def default_config(self) -> DatasetConfig:
+        return self._configs[0]
+
     @staticmethod
     def read_categories_file(path: pathlib.Path) -> List[List[str]]:
         with open(path, newline="") as file:
             return [row for row in csv.reader(file)]

-    @property
-    def default_config(self) -> DatasetConfig:
-        return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})
-
     def make_config(self, **options: Any) -> DatasetConfig:
         for name, arg in options.items():
             if name not in self._valid_options:
......
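
The `_configs` tuple added to `DatasetInfo` enumerates every combination of the valid options, and `default_config` becomes simply the first entry. A standalone sketch of that enumeration (not part of the diff; the option values are made up):

import itertools

valid_options = {"split": ("train", "test"), "image_set": ("Balanced", "By_Merge")}
configs = [
    dict(zip(valid_options.keys(), combination))
    for combination in itertools.product(*valid_options.values())
]
# configs[0] == {"split": "train", "image_set": "Balanced"} plays the role of default_config;
# len(configs) == 4
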
@@ -35,6 +35,7 @@ import torch.distributed as dist
 import torch.utils.data
 from torch.utils.data import IterDataPipe
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader
+from torchdata.datapipes.utils import StreamWrapper

 __all__ = [
@@ -87,6 +88,9 @@ def add_suggestion(
     possibilities = sorted(possibilities)
     suggestions = difflib.get_close_matches(word, possibilities, 1)
     hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
+    if not hint:
+        return msg
+
     return f"{msg.strip()} {hint}"
@@ -172,6 +176,9 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     except ImportError as error:
         raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error

+    if isinstance(buffer, StreamWrapper):
+        buffer = buffer.file_obj
+
     return sio.loadmat(buffer, **kwargs)
......