Unverified Commit b430ba68 authored by Philip Meier, committed by GitHub

simplify OnlineResource.load (#5990)

* simplify OnlineResource.load

* [PoC] merge mock data preparation and loading

* Revert "cache mock data based on config"

This reverts commit 5ed6eedef74865e0baa746a375d5ec1f0ab1bde7.

* Revert "[PoC] merge mock data preparation and loading"

This reverts commit d62747962f9ed6a7b0b80849e7c971efabb5d3da.

* remove preprocess returning a new path in favor of querying twice

* address test comments

* clarify comment

* mypy

* use builtin decompress utility
parent 08c8f0e0
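
The net effect of the change: `preprocess` callables now work in place and return None, and `load` discovers their output by querying the root directory a second time instead of chaining returned paths. A minimal usage sketch of the resulting behavior (the URL and file names here are hypothetical, not taken from the PR):

    import pathlib

    from torchvision.prototype.datasets.utils import HttpResource

    # A gzipped file that should be decompressed once after download. With this PR,
    # "decompress" (or any callable taking a pathlib.Path and returning None) mutates
    # the on-disk state; it no longer returns the new path.
    resource = HttpResource("https://example.com/data.txt.gz", preprocess="decompress")

    root = pathlib.Path("~/datasets").expanduser()

    # First call: nothing matches the stem "data", so the raw file is downloaded and
    # decompressed; the second candidate query then finds data.txt and streams it.
    # Subsequent calls find data.txt directly and skip both download and preprocessing.
    for path, buffer in resource.load(root):
        print(path)
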
import gzip
import pathlib
import sys

import numpy as np
import pytest
import torch
from datasets_utils import make_fake_flo_file, make_tar
from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.datasets.utils import _decompress
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
    torch.testing.assert_close(actual, expected)


class TestOnlineResource:
    class DummyResource(OnlineResource):
        def __init__(self, download_fn=None, **kwargs):
            super().__init__(**kwargs)
            self._download_fn = download_fn

        def _download(self, root):
            if self._download_fn is None:
                raise pytest.UsageError(
                    "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
                )

            return self._download_fn(self, root)

    def _make_file(self, root, *, content, name="file.txt"):
        file = root / name
        with open(file, "w") as fh:
            fh.write(content)

        return file

    def _make_folder(self, root, *, name="folder"):
        folder = root / name
        subfolder = folder / "subfolder"
        subfolder.mkdir(parents=True)

        files = {}
        for idx, root in enumerate([folder, folder, subfolder]):
            content = f"sentinel{idx}"
            file = self._make_file(root, name=f"file{idx}.txt", content=content)
            files[str(file)] = content

        return folder, files

    def _make_tar(self, root, *, name="archive.tar", remove=True):
        folder, files = self._make_folder(root, name=name.split(".")[0])
        archive = make_tar(root, name, folder, remove=remove)
        files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
        return archive, files
    def test_load_file(self, tmp_path):
        content = "sentinel"
        file = self._make_file(tmp_path, content=content)

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)

        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(file)
        assert buffer.read().decode() == content

    def test_load_folder(self, tmp_path):
        folder, files = self._make_folder(tmp_path)

        resource = self.DummyResource(file_name=folder.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)
        assert {path: buffer.read().decode() for path, buffer in dp} == files

    def test_load_archive(self, tmp_path):
        archive, files = self._make_tar(tmp_path)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, TarArchiveLoader)
        assert {path: buffer.read().decode() for path, buffer in dp} == files
    def test_priority_decompressed_gt_raw(self, tmp_path):
        # We don't need to actually compress here. Adding the suffix is sufficient.
        self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
        file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        path, buffer = next(iter(dp))

        assert path == str(file)
        assert buffer.read().decode() == "decompressed_sentinel"

    def test_priority_extracted_gt_decompressed(self, tmp_path):
        archive, _ = self._make_tar(tmp_path, remove=False)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)

        # If the archive had been selected, this would be a `TarArchiveLoader`
        assert isinstance(dp, FileOpener)
    def test_download(self, tmp_path):
        download_fn_was_called = False

        def download_fn(resource, root):
            nonlocal download_fn_was_called
            download_fn_was_called = True

            return self._make_file(root, content="_", name=resource.file_name)

        resource = self.DummyResource(
            file_name="file.txt",
            download_fn=download_fn,
        )

        resource.load(tmp_path)

        assert download_fn_was_called, "`download_fn()` was never called"
    # This tests the `"decompress"` literal as well as a custom callable
    @pytest.mark.parametrize(
        "preprocess",
        [
            "decompress",
            lambda path: _decompress(str(path), remove_finished=True),
        ],
    )
    def test_preprocess_decompress(self, tmp_path, preprocess):
        file_name = "file.txt.gz"
        content = "sentinel"

        def download_fn(resource, root):
            file = root / resource.file_name
            with gzip.open(file, "wb") as fh:
                fh.write(content.encode())

            return file

        resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)

        dp = resource.load(tmp_path)
        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(tmp_path / file_name).replace(".gz", "")
        assert buffer.read().decode() == content
    def test_preprocess_extract(self, tmp_path):
        files = None

        def download_fn(resource, root):
            nonlocal files
            archive, files = self._make_tar(root, name=resource.file_name)
            return archive

        resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)

        dp = resource.load(tmp_path)
        assert files is not None, "`download_fn()` was never called"

        assert isinstance(dp, FileOpener)

        actual = {path: buffer.read().decode() for path, buffer in dp}
        expected = {
            path.replace(resource.file_name, resource.file_name.split(".")[0]): content
            for path, content in files.items()
        }
        assert actual == expected
    def test_preprocess_only_after_download(self, tmp_path):
        file = self._make_file(tmp_path, content="_")

        def preprocess(path):
            raise AssertionError("`preprocess` was called although the file was already present.")

        resource = self.DummyResource(
            file_name=file.name,
            preprocess=preprocess,
        )

        resource.load(tmp_path)


class TestHttpResource:
    def test_resolve_to_http(self, mocker):
        file_name = "data.tar"
......
@@ -2,7 +2,7 @@ import abc
import hashlib
import itertools
import pathlib
from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
from urllib.parse import urlparse
from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ class OnlineResource(abc.ABC):
        *,
        file_name: str,
        sha256: Optional[str] = None,
        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
    ) -> None:
        self.file_name = file_name
        self.sha256 = sha256
@@ -50,14 +50,12 @@
        self._preprocess = preprocess

    @staticmethod
    def _extract(file: pathlib.Path) -> None:
        extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)

    @staticmethod
    def _decompress(file: pathlib.Path) -> None:
        _decompress(str(file), remove_finished=True)
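    # Worked example (illustration only, not part of the diff): for a hypothetical
    # file = pathlib.Path("/data/archive.tar.gz"), file.suffixes == [".tar", ".gz"], so
    # `_extract` unpacks into "/data/archive" while `_decompress` yields "/data/archive.tar".
    # Both helpers now return None; `load` rediscovers their output by globbing again.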

    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
        if path.is_dir():
@@ -91,32 +89,38 @@
    ) -> IterDataPipe[Tuple[str, IO]]:
        root = pathlib.Path(root)
        path = root / self.file_name

        # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
        # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed,
        # which is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
        stem = path.name.replace("".join(path.suffixes), "")
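        # For example (illustration only): with self.file_name == "foo.tar.gz", path.suffixes is
        # [".tar", ".gz"] and stem == "foo", which also matches the decompressed "foo.tar" and the
        # extracted folder "foo".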

        def find_candidates() -> Set[pathlib.Path]:
            # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
            # candidate simultaneously, that would also pick up other files that share the same prefix. For example,
            # the test split of the stanford-cars dataset uses the files
            # - cars_test.tgz
            # - cars_test_annos_withlabels.mat
            # Globbing for `"cars_test*"` picks up both.
            candidates = {file for file in path.parent.glob(f"{stem}.*")}

            folder_candidate = path.parent / stem
            if folder_candidate.exists():
                candidates.add(folder_candidate)

            return candidates

        candidates = find_candidates()

        if not candidates:
            self.download(root, skip_integrity_check=skip_integrity_check)
            if self._preprocess is not None:
                self._preprocess(path)
            candidates = find_candidates()

        # We use the path with the fewest suffixes. This gives us the
        #   extracted > decompressed > raw
        # priority that we want for the best I/O performance.
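        # Illustration (not part of the diff): with candidates {root / "foo", root / "foo.tar",
        # root / "foo.tar.gz"}, len(candidate.suffixes) is 0, 1, and 2 respectively, so the
        # extracted folder root / "foo" wins below.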
        return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))

    @abc.abstractmethod
    def _download(self, root: pathlib.Path) -> None:
......