Unverified Commit 151e1622 authored by Philip Meier, committed by GitHub

fix HttpResource.resolve() with preprocessing (#5669)

* fix HttpResource.resolve() with preprocess set

* fix README

* add safeguard for invalid str inputs
parent 647016bd
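
For context, a hedged sketch of the scenario this commit fixes (the URL and checksum are placeholders, and `resolve()` performs a network request, so treat this as illustration rather than something to execute): `HttpResource.resolve()` follows redirects and may return a freshly constructed resource, and with this fix a configured `preprocess` step is carried over to that resource, as the new tests below verify.

```python
from torchvision.prototype.datasets.utils import HttpResource

# Placeholder URL and checksum, for illustration only.
resource = HttpResource(
    "http://downloads.pytorch.org/data.tar.gz",
    sha256="<sha256 of the archive>",
    preprocess="decompress",
)

# resolve() may return a new resource for the redirected URL. With this fix,
# the new resource keeps the preprocess step configured on the original one.
resolved = resource.resolve()
assert resolved._preprocess is not None
```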
@@ -5,6 +5,7 @@ import pytest
 import torch
 from datasets_utils import make_fake_flo_file
 from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
 from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
@@ -45,3 +46,58 @@ def test_read_flo(tmpdir):
     expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
     torch.testing.assert_close(actual, expected)
+
+
+class TestHttpResource:
+    def test_resolve_to_http(self, mocker):
+        file_name = "data.tar"
+        original_url = f"http://downloads.pytorch.org/{file_name}"
+        redirected_url = original_url.replace("http", "https")
+        sha256_sentinel = "sha256_sentinel"
+
+        def preprocess_sentinel(path):
+            return path
+
+        original_resource = HttpResource(
+            original_url,
+            sha256=sha256_sentinel,
+            preprocess=preprocess_sentinel,
+        )
+
+        mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
+        redirected_resource = original_resource.resolve()
+
+        assert isinstance(redirected_resource, HttpResource)
+        assert redirected_resource.url == redirected_url
+        assert redirected_resource.file_name == file_name
+        assert redirected_resource.sha256 == sha256_sentinel
+        assert redirected_resource._preprocess is preprocess_sentinel
+
+    def test_resolve_to_gdrive(self, mocker):
+        file_name = "data.tar"
+        original_url = f"http://downloads.pytorch.org/{file_name}"
+        id_sentinel = "id-sentinel"
+        redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view"
+        sha256_sentinel = "sha256_sentinel"
+
+        def preprocess_sentinel(path):
+            return path
+
+        original_resource = HttpResource(
+            original_url,
+            sha256=sha256_sentinel,
+            preprocess=preprocess_sentinel,
+        )
+
+        mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
+        redirected_resource = original_resource.resolve()
+
+        assert isinstance(redirected_resource, GDriveResource)
+        assert redirected_resource.id == id_sentinel
+        assert redirected_resource.file_name == file_name
+        assert redirected_resource.sha256 == sha256_sentinel
+        assert redirected_resource._preprocess is preprocess_sentinel
@@ -231,7 +231,7 @@ To generate the `$NAME.categories` file, run `python -m torchvision.prototype.da
 ### What if a resource file forms an I/O bottleneck?
 
 In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
-the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the
-`decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex
-cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should
-return `pathlib.Path` of the preprocessed file or folder.
+the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the
+`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
+preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
+accepts `"decompress"` and `"extract"` to handle these common scenarios.
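
To illustrate the README paragraph above, a minimal sketch of the two accepted forms of `preprocess` (the URLs and checksums are placeholders, not real dataset resources):

```python
import pathlib

from torchvision.prototype.datasets.utils import HttpResource

# Built-in shortcut: decompress the archive after download.
images = HttpResource(
    "https://example.com/images.tar.gz",  # placeholder URL
    sha256="<expected sha256>",
    preprocess="decompress",
)

# Custom preprocessing: gets the path of the raw file and returns the path
# of whatever should be loaded afterwards.
def rearrange(path: pathlib.Path) -> pathlib.Path:
    # ... reorganize files on disk as needed ...
    return path

anns = HttpResource(
    "https://example.com/annotations.tar",  # placeholder URL
    sha256="<expected sha256>",
    preprocess=rearrange,
)
```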
@@ -32,7 +32,7 @@ class Caltech101(Dataset):
         images = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
             sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
-            decompress=True,
+            preprocess="decompress",
         )
         anns = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
......
@@ -51,29 +51,29 @@ class CUB200(Dataset):
             archive = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
                 sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
-                decompress=True,
+                preprocess="decompress",
             )
             segmentations = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz",
                 sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
-                decompress=True,
+                preprocess="decompress",
             )
             return [archive, segmentations]
         else:  # config.year == "2010"
             split = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
                 sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
-                decompress=True,
+                preprocess="decompress",
             )
             images = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz",
                 sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
-                decompress=True,
+                preprocess="decompress",
             )
             anns = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200/annotations.tgz",
                 sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
-                decompress=True,
+                preprocess="decompress",
             )
             return [split, images, anns]
......
@@ -49,7 +49,7 @@ class DTD(Dataset):
         archive = HttpResource(
             "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
             sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
-            decompress=True,
+            preprocess="decompress",
         )
         return [archive]
......
@@ -40,12 +40,12 @@ class OxfordIITPet(Dataset):
         images = HttpResource(
             "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
             sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
-            decompress=True,
+            preprocess="decompress",
         )
         anns = HttpResource(
             "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
             sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
-            decompress=True,
+            preprocess="decompress",
         )
         return [images, anns]
......
@@ -91,7 +91,7 @@ class PCAM(Dataset):
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         return [  # = [images resource, targets resource]
-            GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, decompress=True)
+            GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
             for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
         ]
......
@@ -23,6 +23,7 @@ from torchvision.datasets.utils import (
     _get_redirect_url,
     _get_google_drive_file_id,
 )
+from typing_extensions import Literal
 
 
 class OnlineResource(abc.ABC):
@@ -31,19 +32,22 @@ class OnlineResource(abc.ABC):
         *,
         file_name: str,
         sha256: Optional[str] = None,
-        decompress: bool = False,
-        extract: bool = False,
+        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
 
-        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
-        if extract:
-            self._preprocess = self._extract
-        elif decompress:
-            self._preprocess = self._decompress
-        else:
-            self._preprocess = None
+        if isinstance(preprocess, str):
+            if preprocess == "decompress":
+                preprocess = self._decompress
+            elif preprocess == "extract":
+                preprocess = self._extract
+            else:
+                raise ValueError(
+                    f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string, "
+                    f"but got {preprocess} instead."
+                )
+        self._preprocess = preprocess
 
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
@@ -163,7 +167,6 @@ class HttpResource(OnlineResource):
                 "file_name",
                 "sha256",
                 "_preprocess",
-                "_loader",
             )
         }
......
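
A short sketch of the safeguard added above for invalid string inputs (placeholder URLs; assumes the prototype API exactly as shown in this diff): the two supported strings map to built-in preprocessing steps, and anything else raises a `ValueError` at construction time.

```python
import pytest

from torchvision.prototype.datasets.utils import HttpResource

# The two supported string shortcuts map to built-in preprocessing steps.
HttpResource("https://example.com/data.tar.gz", preprocess="decompress")
HttpResource("https://example.com/data.zip", preprocess="extract")

# Any other string fails fast instead of being silently ignored.
with pytest.raises(ValueError):
    HttpResource("https://example.com/data.tar.gz", preprocess="unzip")
```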