Unverified Commit 4d00ae0f authored by Philip Meier, committed by GitHub

add download functionality to prototype datasets (#5035)

* add download functionality to prototype datasets

* fix annotation

* fix test

* remove iopath

* add comments
parent 4282c9fc
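
For orientation, a hedged sketch of the user-facing call chain this PR introduces; the dataset name and keyword values are illustrative, inferred from the classes touched in this diff rather than a released API:

from torchvision.prototype import datasets

# Top-level entry point (torchvision/prototype/datasets/_api.py): it resolves the
# dataset by name, builds a config from the keyword options, and calls Dataset.load(),
# which in turn calls OnlineResource.load() for every resource. Missing files are
# downloaded on demand and their SHA256 checksums are verified unless
# skip_integrity_check=True.
datapipe = datasets.load("caltech101", split="train", skip_integrity_check=False)

for sample in datapipe:  # an IterDataPipe yielding sample dictionaries
    ...
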
......@@ -29,6 +29,19 @@ __all__ = ["load"]
DEFAULT_TEST_DECODER = object()
class TestResource(datasets.utils.OnlineResource):
def __init__(self, *, dataset_name, dataset_config, **kwargs):
super().__init__(**kwargs)
self.dataset_name = dataset_name
self.dataset_config = dataset_config
def _download(self, _):
raise pytest.UsageError(
f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
f"but this file does not exist."
)
class DatasetMocks:
def __init__(self):
self._mock_data_fns = {}
......@@ -72,7 +85,7 @@ class DatasetMocks:
)
return mock_info
def _get(self, dataset, config):
def _get(self, dataset, config, root):
name = dataset.info.name
resources_and_mock_info = self._cache.get((name, config))
if resources_and_mock_info:
......@@ -87,20 +100,12 @@ class DatasetMocks:
f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
)
root = self._tmp_home / name
root.mkdir(exist_ok=True)
mock_resources = [
TestResource(dataset_name=name, dataset_config=config, file_name=resource.file_name)
for resource in dataset.resources(config)
]
mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)
mock_resources = []
for resource in dataset.resources(config):
path = root / resource.file_name
if not path.exists() and path.is_file():
raise pytest.UsageError(
f"Dataset '{name}' requires the file {path.name} for {config}, but this file does not exist."
)
mock_resources.append(datasets.utils.LocalResource(path))
self._cache[(name, config)] = mock_resources, mock_info
return mock_resources, mock_info
......@@ -109,9 +114,13 @@ class DatasetMocks:
) -> Tuple[IterDataPipe, Dict[str, Any]]:
dataset = find(name)
config = dataset.info.make_config(split=split, **options)
resources, mock_info = self._get(dataset, config)
root = self._tmp_home / name
root.mkdir(exist_ok=True)
resources, mock_info = self._get(dataset, config, root)
datapipe = dataset._make_datapipe(
[resource.to_datapipe() for resource in resources],
[resource.load(root) for resource in resources],
config=config,
decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
)
......
......@@ -211,10 +211,10 @@ class TestDataset:
pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
],
)
def test_to_datapipe_config(self, config, kwarg):
def test_load_config(self, config, kwarg):
dataset = self.DatasetMock()
dataset.to_datapipe("", config=kwarg)
dataset.load("", config=kwarg)
dataset.resources.assert_called_with(config)
......@@ -225,18 +225,19 @@ class TestDataset:
dependency = "fake_dependency"
dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
with pytest.raises(ModuleNotFoundError, match=dependency):
dataset.to_datapipe("root")
dataset.load("root")
def test_resources(self, mocker):
resource_mock = mocker.Mock(spec=["to_datapipe"])
resource_mock = mocker.Mock(spec=["load"])
sentinel = object()
resource_mock.to_datapipe.return_value = sentinel
resource_mock.load.return_value = sentinel
dataset = self.DatasetMock(resources=[resource_mock])
root = "root"
dataset.to_datapipe(root)
dataset.load(root)
resource_mock.to_datapipe.assert_called_with(root)
(call_args, _) = resource_mock.load.call_args
assert call_args[0] == root
(call_args, _) = dataset._make_datapipe.call_args
assert call_args[0][0] is sentinel
......@@ -245,7 +246,7 @@ class TestDataset:
dataset = self.DatasetMock()
sentinel = object()
dataset.to_datapipe("", decoder=sentinel)
dataset.load("", decoder=sentinel)
(_, call_kwargs) = dataset._make_datapipe.call_args
assert call_kwargs["decoder"] is sentinel
......@@ -61,16 +61,16 @@ def load(
name: str,
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER, # type: ignore[assignment]
skip_integrity_check: bool = False,
split: str = "train",
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
name = name.lower()
dataset = find(name)
if decoder is DEFAULT_DECODER:
decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
config = dataset.info.make_config(split=split, **options)
root = os.path.join(home(), name)
root = os.path.join(home(), dataset.name)
return dataset.to_datapipe(root, config=config, decoder=decoder)
return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
......@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
TarArchiveReader,
Shuffler,
Filter,
IterKeyZipper,
......@@ -38,6 +37,7 @@ class Caltech101(Dataset):
images = HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
decompress=True,
)
anns = HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
......@@ -119,11 +119,9 @@ class Caltech101(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = TarArchiveReader(images_dp)
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
anns_dp = TarArchiveReader(anns_dp)
anns_dp = Filter(anns_dp, self._is_ann)
dp = IterKeyZipper(
......@@ -137,8 +135,7 @@ class Caltech101(Dataset):
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
......@@ -185,13 +182,11 @@ class Caltech256(Dataset):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp = Filter(dp, self._is_not_rogue_file)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
return [name.split(".")[1] for name in sorted(dir_names)]
......@@ -8,7 +8,6 @@ from torchdata.datapipes.iter import (
Mapper,
Shuffler,
Filter,
ZipArchiveReader,
Zipper,
IterKeyZipper,
)
......@@ -154,8 +153,6 @@ class CelebA(Dataset):
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
images_dp = ZipArchiveReader(images_dp)
anns_dp = Zipper(
*[
CelebACSVParser(dp, fieldnames=fieldnames)
......
......@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
IterDataPipe,
Filter,
Mapper,
TarArchiveReader,
Shuffler,
)
from torchvision.prototype.datasets.decoder import raw
......@@ -85,7 +84,6 @@ class _CifarBase(Dataset):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
......@@ -93,8 +91,7 @@ class _CifarBase(Dataset):
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
......
......@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
Shuffler,
Filter,
Demultiplexer,
ZipArchiveReader,
Grouper,
IterKeyZipper,
JsonParser,
......@@ -180,13 +179,10 @@ class Coco(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps
images_dp = ZipArchiveReader(images_dp)
if config.annotations is None:
dp = Shuffler(images_dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
meta_dp = ZipArchiveReader(meta_dp)
meta_dp = Filter(
meta_dp,
self._filter_meta_files,
......@@ -234,8 +230,7 @@ class Coco(Dataset):
config = self.default_config
resources = self.resources(config)
dp = resources[1].to_datapipe(pathlib.Path(root) / self.name)
dp = ZipArchiveReader(dp)
dp = resources[1].load(pathlib.Path(root) / self.name)
dp = Filter(
dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
)
......
......@@ -9,8 +9,8 @@ from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
ManualDownloadResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
......@@ -25,6 +25,11 @@ from torchvision.prototype.features import Label, DEFAULT
from torchvision.prototype.utils._internal import FrozenMapping
class ImageNetResource(ManualDownloadResource):
def __init__(self, **kwargs: Any) -> None:
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
class ImageNetLabel(Label):
wnid: Optional[str]
......@@ -81,10 +86,10 @@ class ImageNet(Dataset):
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
name = "test_v10102019" if config.split == "test" else config.split
images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
devkit = HttpResource(
"ILSVRC2012_devkit_t12.tar.gz",
devkit = ImageNetResource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
)
......@@ -139,15 +144,12 @@ class ImageNet(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
images_dp, devkit_dp = resource_dps
images_dp = TarArchiveReader(images_dp)
if config.split == "train":
# the train archive is a tar of tars
dp = TarArchiveReader(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_train_data)
elif config.split == "val":
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False)
devkit_dp = Mapper(devkit_dp, int)
......@@ -177,8 +179,7 @@ class ImageNet(Dataset):
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
resources = self.resources(self.default_config)
devkit_dp = resources[1].to_datapipe(root / self.name)
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = resources[1].load(root / self.name)
devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
meta = next(iter(devkit_dp))[1]
......
......@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
IterDataPipe,
Demultiplexer,
Mapper,
ZipArchiveReader,
Zipper,
Shuffler,
)
......@@ -310,7 +309,6 @@ class EMNIST(_MNISTBase):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
archive_dp = ZipArchiveReader(archive_dp)
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
......
......@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
TarArchiveReader,
Shuffler,
Demultiplexer,
Filter,
......@@ -129,7 +128,6 @@ class SBD(Dataset):
archive_dp, extra_split_dp = resource_dps
archive_dp = resource_dps[0]
archive_dp = TarArchiveReader(archive_dp)
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
......@@ -155,8 +153,7 @@ class SBD(Dataset):
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
dp = Filter(dp, path_comparator("name", "category_names.m"))
dp = LineReader(dp)
dp = Mapper(dp, bytes.decode, input_col=1)
......
......@@ -30,11 +30,11 @@ class SEMEION(Dataset):
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
archive = HttpResource(
data = HttpResource(
"http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
)
return [archive]
return [data]
def _collate_and_decode_sample(
self,
......
......@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
TarArchiveReader,
Shuffler,
Filter,
Demultiplexer,
......@@ -119,7 +118,6 @@ class VOC(Dataset):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
archive_dp = TarArchiveReader(archive_dp)
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
......
......@@ -2,6 +2,7 @@
import argparse
import csv
import pathlib
import sys
from torchvision.prototype import datasets
......@@ -10,7 +11,7 @@ from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
def main(*names, force=False):
root = datasets.home()
root = pathlib.Path(datasets.home())
for name in names:
path = BUILTIN_DIR / f"{name}.categories"
......@@ -24,7 +25,8 @@ def main(*names, force=False):
continue
with open(path, "w", newline="") as file:
csv.writer(file).writerows(categories)
for category in categories:
csv.writer(file).writerow((category,) if isinstance(category, str) else category)
def parse_args(argv=None):
......
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
......@@ -172,12 +172,13 @@ class Dataset(abc.ABC):
def supports_sharded(self) -> bool:
return False
def to_datapipe(
def load(
self,
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
skip_integrity_check: bool = False,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
config = self.info.default_config
......@@ -188,7 +189,9 @@ class Dataset(abc.ABC):
return _make_sharded_datapipe(root, dataset_size)
self.info.check_dependencies()
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
resource_dps = [
resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
......
import os.path
import abc
import hashlib
import itertools
import pathlib
from typing import Optional, Union
import warnings
from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
from urllib.parse import urlparse
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import IoPathFileLoader
from torchdata.datapipes.iter import (
IterableWrapper,
FileLister,
FileLoader,
IterDataPipe,
ZipArchiveReader,
TarArchiveReader,
RarArchiveLoader,
)
from torchvision.datasets.utils import (
download_url,
_detect_file_type,
extract_archive,
_decompress,
download_file_from_google_drive,
)
# FIXME
def compute_sha256(path: pathlib.Path) -> str:
return ""
class OnlineResource(abc.ABC):
def __init__(
self,
*,
file_name: str,
sha256: Optional[str] = None,
decompress: bool = False,
extract: bool = False,
preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
) -> None:
self.file_name = file_name
self.sha256 = sha256
if preprocess and (decompress or extract):
warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
elif extract:
preprocess = self._extract
elif decompress:
preprocess = self._decompress
self._preprocess = preprocess
class LocalResource:
def __init__(self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None) -> None:
self.path = pathlib.Path(path).expanduser().resolve()
self.file_name = self.path.name
self.sha256 = sha256 or compute_sha256(self.path)
if loader is None:
loader = self._default_loader
self._loader = loader
def to_datapipe(self) -> IterDataPipe:
return IoPathFileLoader(IterableWrapper((str(self.path),)), mode="rb") # type: ignore
@staticmethod
def _extract(file: pathlib.Path) -> pathlib.Path:
return pathlib.Path(
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
)
@staticmethod
def _decompress(file: pathlib.Path) -> pathlib.Path:
return pathlib.Path(_decompress(str(file), remove_finished=True))
class OnlineResource:
def __init__(self, url: str, *, sha256: str, file_name: str) -> None:
self.url = url
self.sha256 = sha256
self.file_name = file_name
def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
if path.is_dir():
return FileLoader(FileLister(str(path), recursive=True))
dp = FileLoader(IterableWrapper((str(path),)))
archive_loader = self._guess_archive_loader(path)
if archive_loader:
dp = archive_loader(dp)
return dp
_ARCHIVE_LOADERS = {
".tar": TarArchiveReader,
".zip": ZipArchiveReader,
".rar": RarArchiveLoader,
}
def _guess_archive_loader(
self, path: pathlib.Path
) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
try:
_, archive_type, _ = _detect_file_type(path.name)
except RuntimeError:
return None
return self._ARCHIVE_LOADERS.get(archive_type) # type: ignore[arg-type]
def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe:
path = os.path.join(root, self.file_name)
# FIXME
return IoPathFileLoader(IterableWrapper((str(path),)), mode="rb") # type: ignore
def load(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
) -> IterDataPipe[Tuple[str, IO]]:
root = pathlib.Path(root)
path = root / self.file_name
# Instead of the raw file, there might also be files with fewer suffixes after decompression, or directories
# with no suffixes at all. Thus, we look for all paths that share the raw file's name once all suffixes are
# stripped.
path_candidates = {file for file in path.parent.glob(path.name.replace("".join(path.suffixes), "") + "*")}
# If we don't find anything, we try to download the raw file.
if not path_candidates:
path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
# If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
if path_candidates == {path}:
if self._preprocess:
path = self._preprocess(path)
# Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
# that we want.
else:
path = min(path_candidates, key=lambda path: len(path.suffixes))
return self._loader(path)
@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
pass
def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
root = pathlib.Path(root)
self._download(root)
path = root / self.file_name
if self.sha256 and not skip_integrity_check:
self._check_sha256(path)
return path
def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
hash = hashlib.sha256()
with open(path, "rb") as file:
for chunk in iter(lambda: file.read(chunk_size), b""):
hash.update(chunk)
sha256 = hash.hexdigest()
if sha256 != self.sha256:
raise RuntimeError(
f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
f"{sha256} != {self.sha256}"
)
# TODO: add support for mirrors
# TODO: add support for http -> https
class HttpResource(OnlineResource):
def __init__(self, url: str, *, sha256: str, file_name: Optional[str] = None) -> None:
if not file_name:
file_name = os.path.basename(urlparse(url).path)
super().__init__(url, sha256=sha256, file_name=file_name)
def __init__(
self, url: str, *, file_name: Optional[str] = None, mirrors: Optional[Sequence[str]] = None, **kwargs: Any
) -> None:
super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
self.url = url
self.mirrors = mirrors
def _download(self, root: pathlib.Path) -> None:
for url in itertools.chain((self.url,), self.mirrors or ()):
try:
download_url(url, str(root), filename=self.file_name, md5=None)
# TODO: make this more precise
except Exception:
continue
return
else:
# TODO: make this more informative
raise RuntimeError("Download failed!")
class GDriveResource(OnlineResource):
def __init__(self, id: str, *, sha256: str, file_name: str) -> None:
# TODO: can we maybe do a head request to extract the file name?
url = f"https://drive.google.com/file/d/{id}/view"
super().__init__(url, sha256=sha256, file_name=file_name)
def __init__(self, id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.id = id
def _download(self, root: pathlib.Path) -> None:
download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
class ManualDownloadResource(OnlineResource):
def __init__(self, instructions: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.instructions = instructions
def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError(
f"The file {self.file_name} cannot be downloaded automatically. "
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)
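
As a hedged usage sketch of the resource classes above (the URL and checksum are placeholders, not a real dataset): an HttpResource created with decompress=True reuses an already decompressed or extracted copy when one sits next to the raw file, preferring the candidate with the fewest suffixes, and only downloads when nothing is found.

import pathlib

# Hypothetical archive: both the URL and the sha256 value below are placeholders.
resource = HttpResource(
    "https://example.com/images.tar.gz",
    sha256="0" * 64,
    decompress=True,
)

root = pathlib.Path("~/datasets/example").expanduser()
# load() globs for "images*" under root. If only images.tar.gz is present, it is
# decompressed first; if images.tar or an images/ directory already exists, that
# candidate wins because it has fewer suffixes. If nothing is found, the file is
# downloaded and its checksum verified (unless skip_integrity_check=True).
dp = resource.load(root)
for path, handle in dp:  # an IterDataPipe of (path, file handle) tuples
    ...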