Unverified Commit 151e1622 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix HttpResource.resolve() with preprocessing (#5669)

* fix HttpResource.resolve() with preprocess set

* fix README

* add safeguard for invalid str inputs
parent 647016bd
......@@ -5,6 +5,7 @@ import pytest
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
......@@ -45,3 +46,58 @@ def test_read_flo(tmpdir):
expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
torch.testing.assert_close(actual, expected)
class TestHttpResource:
    """Regression tests for ``HttpResource.resolve()`` when ``preprocess`` is set."""

    def test_resolve_to_http(self, mocker):
        """Resolving a plain-HTTP redirect must carry ``sha256`` and ``preprocess`` over."""
        name = "data.tar"
        url = f"http://downloads.pytorch.org/{name}"
        target_url = url.replace("http", "https")
        checksum = "sha256_sentinel"

        def marker_preprocess(path):
            return path

        resource = HttpResource(url, sha256=checksum, preprocess=marker_preprocess)

        mocker.patch(
            "torchvision.prototype.datasets.utils._resource._get_redirect_url",
            return_value=target_url,
        )
        resolved = resource.resolve()

        assert isinstance(resolved, HttpResource)
        assert resolved.url == target_url
        assert resolved.file_name == name
        assert resolved.sha256 == checksum
        assert resolved._preprocess is marker_preprocess

    def test_resolve_to_gdrive(self, mocker):
        """Resolving a Google-Drive redirect must yield a GDriveResource that keeps
        ``sha256`` and ``preprocess``."""
        name = "data.tar"
        url = f"http://downloads.pytorch.org/{name}"
        gdrive_id = "id-sentinel"
        target_url = f"https://drive.google.com/file/d/{gdrive_id}/view"
        checksum = "sha256_sentinel"

        def marker_preprocess(path):
            return path

        resource = HttpResource(url, sha256=checksum, preprocess=marker_preprocess)

        mocker.patch(
            "torchvision.prototype.datasets.utils._resource._get_redirect_url",
            return_value=target_url,
        )
        resolved = resource.resolve()

        assert isinstance(resolved, GDriveResource)
        assert resolved.id == gdrive_id
        assert resolved.file_name == name
        assert resolved.sha256 == checksum
        assert resolved._preprocess is marker_preprocess
......@@ -231,7 +231,7 @@ To generate the `$NAME.categories` file, run `python -m torchvision.prototype.da
### What if a resource file forms an I/O bottleneck?
In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the
`decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex
cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should
return `pathlib.Path` of the preprocessed file or folder.
the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts a
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]`, where the input is the path of the raw
file and the return value is the path of the preprocessed file or folder to load. For convenience, `preprocess` also
accepts the strings `"decompress"` and `"extract"` to handle these common scenarios.
......@@ -32,7 +32,7 @@ class Caltech101(Dataset):
images = HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
decompress=True,
preprocess="decompress",
)
anns = HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
......
......@@ -51,29 +51,29 @@ class CUB200(Dataset):
archive = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
decompress=True,
preprocess="decompress",
)
segmentations = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz",
sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
decompress=True,
preprocess="decompress",
)
return [archive, segmentations]
else: # config.year == "2010"
split = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
decompress=True,
preprocess="decompress",
)
images = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz",
sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
decompress=True,
preprocess="decompress",
)
anns = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200/annotations.tgz",
sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
decompress=True,
preprocess="decompress",
)
return [split, images, anns]
......
......@@ -49,7 +49,7 @@ class DTD(Dataset):
archive = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
decompress=True,
preprocess="decompress",
)
return [archive]
......
......@@ -40,12 +40,12 @@ class OxfordIITPet(Dataset):
images = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
decompress=True,
preprocess="decompress",
)
anns = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
decompress=True,
preprocess="decompress",
)
return [images, anns]
......
......@@ -91,7 +91,7 @@ class PCAM(Dataset):
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
    """Return the online resources for the given split.

    Each entry in ``self._RESOURCES[config.split]`` is a
    ``(file_name, gdrive_id, sha256)`` triple describing one Google-Drive
    hosted file; the files are gzip-compressed, hence ``preprocess="decompress"``.
    """
    return [  # = [images resource, targets resource]
        GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
        for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
    ]
......
......@@ -23,6 +23,7 @@ from torchvision.datasets.utils import (
_get_redirect_url,
_get_google_drive_file_id,
)
from typing_extensions import Literal
class OnlineResource(abc.ABC):
    """Base class for a remotely hosted dataset resource.

    NOTE(review): the ``def __init__(self,`` line was cut by the diff hunk and is
    reconstructed here; the parameter list below matches the visible post-change code.
    """

    def __init__(
        self,
        *,
        file_name: str,
        sha256: Optional[str] = None,
        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
    ) -> None:
        """
        Args:
            file_name: Name the downloaded file is stored under.
            sha256: Expected SHA-256 checksum of the download, or ``None`` to skip verification.
            preprocess: Either the string ``"decompress"`` / ``"extract"`` to select the
                corresponding built-in handler, or a callable that receives the path of the
                raw file and returns the path of the preprocessed file or folder.

        Raises:
            ValueError: If ``preprocess`` is a string other than ``"decompress"`` or ``"extract"``.
        """
        self.file_name = file_name
        self.sha256 = sha256

        # Map the convenience string values onto the built-in handlers; any other
        # string is rejected early instead of failing later during loading.
        if isinstance(preprocess, str):
            if preprocess == "decompress":
                preprocess = self._decompress
            elif preprocess == "extract":
                preprocess = self._extract
            else:
                raise ValueError(
                    "Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string, "
                    f"but got {preprocess} instead."
                )
        self._preprocess = preprocess
@staticmethod
def _extract(file: pathlib.Path) -> pathlib.Path:
......@@ -163,7 +167,6 @@ class HttpResource(OnlineResource):
"file_name",
"sha256",
"_preprocess",
"_loader",
)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment