Unverified Commit 7fe51368 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add warnings checks for v2 namespaces and deprecated files (#7288)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent a192c95e
...@@ -8,8 +8,10 @@ import os ...@@ -8,8 +8,10 @@ import os
import pathlib import pathlib
import random import random
import shutil import shutil
import sys
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from subprocess import CalledProcessError, check_output, STDOUT
from typing import Callable, Sequence, Tuple, Union from typing import Callable, Sequence, Tuple, Union
import numpy as np import numpy as np
...@@ -838,3 +840,22 @@ class InfoBase: ...@@ -838,3 +840,22 @@ class InfoBase:
if isinstance(device, torch.device): if isinstance(device, torch.device):
device = device.type device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict()) return self.closeness_kwargs.get((test_id, dtype, device), dict())
def assert_run_python_script(source_code):
"""Utility to check assertions in an independent Python subprocess.
The script provided in the source code should return 0 and not print
anything on stderr or stdout. Taken from scikit-learn test utils.
source_code (str): The Python source code to execute.
"""
with tempfile.NamedTemporaryFile(mode="wb") as f:
f.write(source_code.encode())
f.flush()
cmd = [sys.executable, f.name]
try:
out = check_output(cmd, stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
if out != b"":
raise AssertionError(out.decode())
...@@ -4,11 +4,12 @@ import numpy as np ...@@ -4,11 +4,12 @@ import numpy as np
import pytest import pytest
import torch import torch
import torchvision import torchvision
from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
torchvision.disable_beta_transforms_warning() torchvision.disable_beta_transforms_warning()
from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
def pytest_configure(config): def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems) # register an additional marker (see pytest_collection_modifyitems)
......
...@@ -2,6 +2,7 @@ import math ...@@ -2,6 +2,7 @@ import math
import os import os
import random import random
import re import re
import textwrap
import warnings import warnings
from functools import partial from functools import partial
...@@ -24,7 +25,7 @@ try: ...@@ -24,7 +25,7 @@ try:
except ImportError: except ImportError:
stats = None stats = None
from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes from common_utils import assert_equal, assert_run_python_script, cycle_over, float_dtypes, int_dtypes
GRACE_HOPPER = get_file_path_2( GRACE_HOPPER = get_file_path_2(
...@@ -2266,5 +2267,35 @@ def test_random_grayscale_with_grayscale_input(): ...@@ -2266,5 +2267,35 @@ def test_random_grayscale_with_grayscale_input():
torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor) torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)
# TODO: remove in 0.17 when we can delete functional_pil.py and functional_tensor.py
@pytest.mark.parametrize(
"import_statement",
(
"from torchvision.transforms import functional_pil",
"from torchvision.transforms import functional_tensor",
"from torchvision.transforms.functional_tensor import resize",
"from torchvision.transforms.functional_pil import resize",
),
)
@pytest.mark.parametrize("from_private", (True, False))
def test_functional_deprecation_warning(import_statement, from_private):
if from_private:
import_statement = import_statement.replace("functional", "_functional")
source = f"""
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
{import_statement}
"""
else:
source = f"""
import pytest
with pytest.warns(UserWarning, match="removed in 0.17"):
{import_statement}
"""
assert_run_python_script(textwrap.dedent(source))
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -2,6 +2,7 @@ import itertools ...@@ -2,6 +2,7 @@ import itertools
import pathlib import pathlib
import random import random
import re import re
import textwrap
import warnings import warnings
from collections import defaultdict from collections import defaultdict
...@@ -14,6 +15,7 @@ import torchvision.transforms.v2 as transforms ...@@ -14,6 +15,7 @@ import torchvision.transforms.v2 as transforms
from common_utils import ( from common_utils import (
assert_equal, assert_equal,
assert_run_python_script,
cpu_and_gpu, cpu_and_gpu,
make_bounding_box, make_bounding_box,
make_bounding_boxes, make_bounding_boxes,
...@@ -2045,3 +2047,52 @@ def test_sanitize_bounding_boxes_errors(): ...@@ -2045,3 +2047,52 @@ def test_sanitize_bounding_boxes_errors():
) )
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes) transforms.SanitizeBoundingBoxes()(different_sizes)
@pytest.mark.parametrize(
"import_statement",
(
"from torchvision.transforms import v2",
"import torchvision.transforms.v2",
"from torchvision.transforms.v2 import Resize",
"import torchvision.transforms.v2.functional",
"from torchvision.transforms.v2.functional import resize",
"from torchvision import datapoints",
"from torchvision.datapoints import Image",
"from torchvision.datasets import wrap_dataset_for_transforms_v2",
),
)
@pytest.mark.parametrize("call_disable_warning", (True, False))
def test_warnings_v2_namespaces(import_statement, call_disable_warning):
if call_disable_warning:
source = f"""
import warnings
import torchvision
torchvision.disable_beta_transforms_warning()
with warnings.catch_warnings():
warnings.simplefilter("error")
{import_statement}
"""
else:
source = f"""
import pytest
with pytest.warns(UserWarning, match="v2 namespaces are still Beta"):
{import_statement}
"""
assert_run_python_script(textwrap.dedent(source))
def test_no_warnings_v1_namespace():
source = """
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
import torchvision.transforms
from torchvision import transforms
import torchvision.transforms.functional
from torchvision.transforms import Resize
from torchvision.transforms.functional import resize
from torchvision import datasets
from torchvision.datasets import ImageNet
"""
assert_run_python_script(textwrap.dedent(source))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment