"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d8ff3193758b2c3aec122deb427b58aef7c54b0a"
Unverified commit b430ba68 authored by Philip Meier, committed by GitHub

simplify OnlineResource.load (#5990)

* simplify OnlineResource.load

* [PoC] merge mock data preparation and loading

* Revert "cache mock data based on config"

This reverts commit 5ed6eedef74865e0baa746a375d5ec1f0ab1bde7.

* Revert "[PoC] merge mock data preparation and loading"

This reverts commit d62747962f9ed6a7b0b80849e7c971efabb5d3da.

* remove preprocess returning a new path in favor of querying twice (see the sketch after this list)

* address test comments

* clarify comment

* mypy

* use builtin decompress utility
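The key design point is in the last few bullets: `preprocess` no longer returns a new path; `load` simply queries the file system again after preprocessing and picks the candidate with the fewest suffixes. A minimal standalone sketch of that selection rule (the candidate names are hypothetical, not part of this diff):

import pathlib

# Hypothetical candidates that can coexist for a resource downloaded as "foo.tar.gz":
# the raw download, the decompressed archive, and the extracted folder.
candidates = {
    pathlib.Path("foo.tar.gz"),  # raw: suffixes [".tar", ".gz"]
    pathlib.Path("foo.tar"),  # decompressed: suffixes [".tar"]
    pathlib.Path("foo"),  # extracted folder: no suffixes
}

# The candidate with the fewest suffixes wins, which yields the
# extracted > decompressed > raw priority that OnlineResource.load wants.
best = min(candidates, key=lambda candidate: len(candidate.suffixes))
assert best == pathlib.Path("foo")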
import gzip
import pathlib
import sys

import numpy as np
import pytest
import torch
from datasets_utils import make_fake_flo_file, make_tar
from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.datasets.utils import _decompress
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
    torch.testing.assert_close(actual, expected)
class TestOnlineResource:
    class DummyResource(OnlineResource):
        def __init__(self, download_fn=None, **kwargs):
            super().__init__(**kwargs)
            self._download_fn = download_fn

        def _download(self, root):
            if self._download_fn is None:
                raise pytest.UsageError(
                    "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
                )

            return self._download_fn(self, root)

    def _make_file(self, root, *, content, name="file.txt"):
        file = root / name
        with open(file, "w") as fh:
            fh.write(content)

        return file

    def _make_folder(self, root, *, name="folder"):
        folder = root / name
        subfolder = folder / "subfolder"
        subfolder.mkdir(parents=True)

        files = {}
        for idx, root in enumerate([folder, folder, subfolder]):
            content = f"sentinel{idx}"
            file = self._make_file(root, name=f"file{idx}.txt", content=content)
            files[str(file)] = content

        return folder, files

    def _make_tar(self, root, *, name="archive.tar", remove=True):
        folder, files = self._make_folder(root, name=name.split(".")[0])
        archive = make_tar(root, name, folder, remove=remove)
        files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}

        return archive, files

    def test_load_file(self, tmp_path):
        content = "sentinel"
        file = self._make_file(tmp_path, content=content)

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)

        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(file)
        assert buffer.read().decode() == content

    def test_load_folder(self, tmp_path):
        folder, files = self._make_folder(tmp_path)

        resource = self.DummyResource(file_name=folder.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)
        assert {path: buffer.read().decode() for path, buffer in dp} == files

    def test_load_archive(self, tmp_path):
        archive, files = self._make_tar(tmp_path)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, TarArchiveLoader)
        assert {path: buffer.read().decode() for path, buffer in dp} == files

    def test_priority_decompressed_gt_raw(self, tmp_path):
        # We don't need to actually compress here. Adding the suffix is sufficient.
        self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
        file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        path, buffer = next(iter(dp))

        assert path == str(file)
        assert buffer.read().decode() == "decompressed_sentinel"

    def test_priority_extracted_gt_decompressed(self, tmp_path):
        archive, _ = self._make_tar(tmp_path, remove=False)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)
        # If the archive had been selected, this would be a `TarArchiveLoader`
        assert isinstance(dp, FileOpener)

    def test_download(self, tmp_path):
        download_fn_was_called = False

        def download_fn(resource, root):
            nonlocal download_fn_was_called
            download_fn_was_called = True

            return self._make_file(root, content="_", name=resource.file_name)

        resource = self.DummyResource(
            file_name="file.txt",
            download_fn=download_fn,
        )

        resource.load(tmp_path)

        assert download_fn_was_called, "`download_fn()` was never called"

    # This tests the `"decompress"` literal as well as a custom callable
    @pytest.mark.parametrize(
        "preprocess",
        [
            "decompress",
            lambda path: _decompress(str(path), remove_finished=True),
        ],
    )
    def test_preprocess_decompress(self, tmp_path, preprocess):
        file_name = "file.txt.gz"
        content = "sentinel"

        def download_fn(resource, root):
            file = root / resource.file_name
            with gzip.open(file, "wb") as fh:
                fh.write(content.encode())

            return file

        resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)

        dp = resource.load(tmp_path)
        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(tmp_path / file_name).replace(".gz", "")
        assert buffer.read().decode() == content

    def test_preprocess_extract(self, tmp_path):
        files = None

        def download_fn(resource, root):
            nonlocal files
            archive, files = self._make_tar(root, name=resource.file_name)
            return archive

        resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)

        dp = resource.load(tmp_path)
        assert files is not None, "`download_fn()` was never called"

        assert isinstance(dp, FileOpener)

        actual = {path: buffer.read().decode() for path, buffer in dp}
        expected = {
            path.replace(resource.file_name, resource.file_name.split(".")[0]): content
            for path, content in files.items()
        }
        assert actual == expected

    def test_preprocess_only_after_download(self, tmp_path):
        file = self._make_file(tmp_path, content="_")

        def preprocess(path):
            raise AssertionError("`preprocess` was called although the file was already present.")

        resource = self.DummyResource(
            file_name=file.name,
            preprocess=preprocess,
        )

        resource.load(tmp_path)
class TestHttpResource:
    def test_resolve_to_http(self, mocker):
        file_name = "data.tar"
...
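Note the contract these tests pin down: a `preprocess` callable now operates on the downloaded path purely for its side effect, and its return value is ignored because `load` re-queries the directory afterwards. A sketch of a custom callable under the new contract (the unpacking strategy is illustrative, not taken from this diff):

import pathlib
import shutil

def preprocess(path: pathlib.Path) -> None:
    # Extract next to the download into a folder named after the archive stem;
    # nothing is returned, since OnlineResource.load globs for candidates again.
    shutil.unpack_archive(str(path), extract_dir=str(path.parent / path.name.split(".")[0]))
    path.unlink()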
@@ -2,7 +2,7 @@ import abc
import hashlib
import itertools
import pathlib
from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
from urllib.parse import urlparse

from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ class OnlineResource(abc.ABC):
        *,
        file_name: str,
        sha256: Optional[str] = None,
        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
    ) -> None:
        self.file_name = file_name
        self.sha256 = sha256
@@ -50,14 +50,12 @@
        self._preprocess = preprocess

    @staticmethod
    def _extract(file: pathlib.Path) -> None:
        extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)

    @staticmethod
    def _decompress(file: pathlib.Path) -> None:
        _decompress(str(file), remove_finished=True)

    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
        if path.is_dir():
@@ -91,32 +89,38 @@
    ) -> IterDataPipe[Tuple[str, IO]]:
        root = pathlib.Path(root)
        path = root / self.file_name

        # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
        # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed,
        # which is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
        stem = path.name.replace("".join(path.suffixes), "")

        def find_candidates() -> Set[pathlib.Path]:
            # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
            # candidate simultaneously, that would also pick up other files that share the same prefix. For example,
            # the test split of the stanford-cars dataset uses the files
            # - cars_test.tgz
            # - cars_test_annos_withlabels.mat
            # Globbing for `"cars_test*"` picks up both.
            candidates = {file for file in path.parent.glob(f"{stem}.*")}
            folder_candidate = path.parent / stem
            if folder_candidate.exists():
                candidates.add(folder_candidate)

            return candidates

        candidates = find_candidates()

        if not candidates:
            self.download(root, skip_integrity_check=skip_integrity_check)
            if self._preprocess is not None:
                self._preprocess(path)
            candidates = find_candidates()

        # We use the path with the fewest suffixes. This gives us the
        #   extracted > decompressed > raw
        # priority that we want for the best I/O performance.
        return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
    @abc.abstractmethod
    def _download(self, root: pathlib.Path) -> None:
...
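The glob pitfall described in the comment above is easy to verify; a small self-contained check (file names taken from the comment, the temporary directory is just scaffolding):

import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    # The test split of stanford-cars ships two files sharing a prefix.
    (root / "cars_test.tgz").touch()
    (root / "cars_test_annos_withlabels.mat").touch()

    # "cars_test*" picks up the unrelated annotation file as well ...
    assert len(list(root.glob("cars_test*"))) == 2
    # ... while "cars_test.*" matches only names whose stem is exactly "cars_test".
    assert [p.name for p in root.glob("cars_test.*")] == ["cars_test.tgz"]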