Unverified Commit ac1512b6 authored by vfdev, committed by GitHub

Added wrap_dataset_for_transforms_v2 into datasets and handled beta w… (#7279)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 56b04976
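
The commit re-exports the transforms v2 wrapper under torchvision.datasets. A minimal usage sketch of the new import path (not part of the diff; the CocoDetection paths are placeholders, and any dataset with a v2 wrapper would do):

from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2  # this import now emits the v2 beta warning

# Placeholder paths; CocoDetection is one dataset for which a v2 wrapper exists.
coco = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
coco = wrap_dataset_for_transforms_v2(coco)  # wrapped samples return datapoints (e.g. BoundingBox) in the target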
@@ -41,7 +41,7 @@ jobs:
# Create Conda Env
conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy
conda activate /work/ci_env
# Install PyTorch, Torchvision, and testing libraries
set -ex
conda install \
@@ -55,3 +55,9 @@ jobs:
# Run Tests
python3 -m torch.utils.collect_env
python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20
# Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2"
# We keep them separate to avoid any side effects due to warnings / imports.
# TODO: Remove this and add proper tests (possibly using a sub-process solution as described
# in https://github.com/pytorch/vision/pull/7269).
python3 -m pytest -v test/check_v2_dataset_warnings.py
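
The TODO above points at a sub-process based alternative; a rough sketch of that idea (a hypothetical helper, not necessarily the approach described in the referenced PR): run the import in a fresh interpreter so module-level warning state cannot leak between tests.

import subprocess
import sys

def import_emits_beta_warning():
    # Hypothetical helper: import the wrapper in a clean interpreter and
    # report whether a UserWarning was raised during the import.
    code = (
        "import warnings\n"
        "with warnings.catch_warnings(record=True) as w:\n"
        "    warnings.simplefilter('always')\n"
        "    from torchvision.datasets import wrap_dataset_for_transforms_v2\n"
        "print(any(issubclass(x.category, UserWarning) for x in w))\n"
    )
    result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, check=True)
    return result.stdout.strip().endswith("True")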
import pytest


def test_warns_if_imported_from_datasets(mocker):
    mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True)

    import torchvision

    with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING):
        from torchvision.datasets import wrap_dataset_for_transforms_v2

    assert callable(wrap_dataset_for_transforms_v2)


@pytest.mark.filterwarnings("error")
def test_no_warns_if_imported_from_datasets():
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    assert callable(wrap_dataset_for_transforms_v2)
@@ -584,8 +584,8 @@ class DatasetTestCase(unittest.TestCase):
    @test_all_configs
    def test_transforms_v2_wrapper(self, config):
        from torchvision.datapoints import wrap_dataset_for_transforms_v2
        from torchvision.datapoints._datapoint import Datapoint
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        try:
            with self.create_dataset(config) as (dataset, _):
...
@@ -8,6 +8,7 @@ import os
import pathlib
import pickle
import random
import re
import shutil
import string
import unittest
@@ -3309,5 +3310,47 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
    pass
class TestDatasetWrapper:
    def test_unknown_type(self):
        unknown_object = object()
        with pytest.raises(
            TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
        ):
            datasets.wrap_dataset_for_transforms_v2(unknown_object)

    def test_unknown_dataset(self):
        class MyVisionDataset(datasets.VisionDataset):
            pass

        dataset = MyVisionDataset("root")

        with pytest.raises(TypeError, match="No wrapper exist"):
            datasets.wrap_dataset_for_transforms_v2(dataset)

    def test_missing_wrapper(self):
        dataset = datasets.FakeData()

        with pytest.raises(TypeError, match="please open an issue"):
            datasets.wrap_dataset_for_transforms_v2(dataset)

    def test_subclass(self, mocker):
        from torchvision import datapoints

        sentinel = object()
        mocker.patch.dict(
            datapoints._dataset_wrapper.WRAPPER_FACTORIES,
            clear=False,
            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
        )

        class MyFakeData(datasets.FakeData):
            pass

        dataset = MyFakeData()
        wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
        assert wrapped_dataset[0] is sentinel
if __name__ == "__main__":
    unittest.main()
import re
import pytest
import torch
from PIL import Image
from torchvision import datapoints, datasets
from torchvision import datapoints
from torchvision.prototype import datapoints as proto_datapoints
@@ -163,43 +161,3 @@ def test_bbox_instance(data, format):
    if isinstance(format, str):
        format = datapoints.BoundingBoxFormat.from_str(format.upper())
    assert bboxes.format == format
class TestDatasetWrapper:
    def test_unknown_type(self):
        unknown_object = object()
        with pytest.raises(
            TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
        ):
            datapoints.wrap_dataset_for_transforms_v2(unknown_object)

    def test_unknown_dataset(self):
        class MyVisionDataset(datasets.VisionDataset):
            pass

        dataset = MyVisionDataset("root")

        with pytest.raises(TypeError, match="No wrapper exist"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_missing_wrapper(self):
        dataset = datasets.FakeData()

        with pytest.raises(TypeError, match="please open an issue"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_subclass(self, mocker):
        sentinel = object()
        mocker.patch.dict(
            datapoints._dataset_wrapper.WRAPPER_FACTORIES,
            clear=False,
            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
        )

        class MyFakeData(datasets.FakeData):
            pass

        dataset = MyFakeData()
        wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)
        assert wrapped_dataset[0] is sentinel
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
from ._dataset_wrapper import wrap_dataset_for_transforms_v2  # type: ignore[attr-defined] # usort: skip

from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

if _WARN_ABOUT_BETA_TRANSFORMS:
    import warnings
...
@@ -128,3 +128,18 @@ __all__ = (
    "InStereo2k",
    "ETH3DStereo",
)
# We override this module's attribute lookup (PEP 562) to handle the import
#     from torchvision.datasets import wrap_dataset_for_transforms_v2
# so that it emits the transforms v2 beta-state warning from torchvision.datapoints,
# while avoiding that warning when any other attribute is imported
# from torchvision.datasets.
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    if name in ("wrap_dataset_for_transforms_v2",):
        from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2

        return wrap_dataset_for_transforms_v2

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
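
As an illustration of the lazy lookup above, a small sketch (assuming a fresh interpreter where torchvision.datapoints has not been imported yet): only the access to wrap_dataset_for_transforms_v2 goes through __getattr__ and triggers the beta warning.

import warnings
import torchvision.datasets as datasets  # importing the module itself should not warn

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = datasets.MNIST  # regular attribute, __getattr__ is never consulted
assert not caught

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    wrapper = datasets.wrap_dataset_for_transforms_v2  # lazily imports torchvision.datapoints
# The datapoints import emits the beta UserWarning (when
# torchvision._WARN_ABOUT_BETA_TRANSFORMS is enabled and datapoints
# was not imported earlier in the process).
assert any(issubclass(w.category, UserWarning) for w in caught)
assert callable(wrapper)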