"vscode:/vscode.git/clone" did not exist on "f9b074ce3e9b9ca94f3d8a7ff2bb47ca9a8284f1"
Unverified Commit ac1512b6 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added wrap_dataset_for_transforms_v2 into datasets and handled beta w… (#7279)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 56b04976
......@@ -55,3 +55,9 @@ jobs:
# Run Tests
python3 -m torch.utils.collect_env
python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20
# Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2"
# We keep them separate to avoid any side effects due to warnings / imports.
# TODO: Remove this and add proper tests (possibly using a sub-process solution as described
# in https://github.com/pytorch/vision/pull/7269).
python3 -m pytest -v test/check_v2_dataset_warnings.py
import pytest
def test_warns_if_imported_from_datasets(mocker):
mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True)
import torchvision
with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING):
from torchvision.datasets import wrap_dataset_for_transforms_v2
assert callable(wrap_dataset_for_transforms_v2)
@pytest.mark.filterwarnings("error")
def test_no_warns_if_imported_from_datasets():
from torchvision.datasets import wrap_dataset_for_transforms_v2
assert callable(wrap_dataset_for_transforms_v2)
......@@ -584,8 +584,8 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints import wrap_dataset_for_transforms_v2
from torchvision.datapoints._datapoint import Datapoint
from torchvision.datasets import wrap_dataset_for_transforms_v2
try:
with self.create_dataset(config) as (dataset, _):
......
......@@ -8,6 +8,7 @@ import os
import pathlib
import pickle
import random
import re
import shutil
import string
import unittest
......@@ -3309,5 +3310,47 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
pass
class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datasets.wrap_dataset_for_transforms_v2(unknown_object)
def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass
dataset = MyVisionDataset("root")
with pytest.raises(TypeError, match="No wrapper exist"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_missing_wrapper(self):
dataset = datasets.FakeData()
with pytest.raises(TypeError, match="please open an issue"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker):
from torchvision import datapoints
sentinel = object()
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
)
class MyFakeData(datasets.FakeData):
pass
dataset = MyFakeData()
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
assert wrapped_dataset[0] is sentinel
if __name__ == "__main__":
unittest.main()
import re
import pytest
import torch
from PIL import Image
from torchvision import datapoints, datasets
from torchvision import datapoints
from torchvision.prototype import datapoints as proto_datapoints
......@@ -163,43 +161,3 @@ def test_bbox_instance(data, format):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format
class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datapoints.wrap_dataset_for_transforms_v2(unknown_object)
def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass
dataset = MyVisionDataset("root")
with pytest.raises(TypeError, match="No wrapper exist"):
datapoints.wrap_dataset_for_transforms_v2(dataset)
def test_missing_wrapper(self):
dataset = datasets.FakeData()
with pytest.raises(TypeError, match="please open an issue"):
datapoints.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker):
sentinel = object()
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
)
class MyFakeData(datasets.FakeData):
pass
dataset = MyFakeData()
wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)
assert wrapped_dataset[0] is sentinel
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
if _WARN_ABOUT_BETA_TRANSFORMS:
import warnings
......
......@@ -128,3 +128,18 @@ __all__ = (
"InStereo2k",
"ETH3DStereo",
)
# We override current module's attributes to handle the import:
# from torchvision.datasets import wrap_dataset_for_transforms_v2
# with beta state v2 warning from torchvision.datapoints
# We also want to avoid raising the warning when importing other attributes
# from torchvision.datasets
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
if name in ("wrap_dataset_for_transforms_v2",):
from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
return wrap_dataset_for_transforms_v2
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment