Unverified Commit 1a9ff0d7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Port remaining transforms tests (#7954)

parent 997384cf
......@@ -272,57 +272,6 @@ class TestSmoke:
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
# Smoke test: every AutoAugment-family transform must run without raising on
# all supported input kinds (tv_tensor images, plain tensors, PIL images,
# videos), for GRAY and RGB color spaces, with and without extra batch dims.
@parametrize(
    [
        (
            transform,
            itertools.chain.from_iterable(
                fn(
                    color_spaces=[
                        "GRAY",
                        "RGB",
                    ],
                    dtypes=[torch.uint8],
                    extra_dims=[(), (4,)],
                    # Only the video maker accepts a num_frames argument.
                    **(dict(num_frames=[3]) if fn is make_videos else dict()),
                )
                for fn in [
                    make_images,
                    make_vanilla_tensor_images,
                    make_pil_images,
                    make_videos,
                ]
            ),
        )
        for transform in (
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AutoAugment(),
            transforms.AugMix(),
        )
    ]
)
def test_auto_augment(self, transform, input):
    # Only checks that the call does not crash; output correctness is
    # covered by dedicated tests elsewhere.
    transform(input)
# Smoke test: Normalize must accept float32 RGB inputs of every tensor-like
# kind (tv_tensor images, plain tensors, videos). PIL images are excluded
# because Normalize does not support them.
@parametrize(
    [
        (
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
            itertools.chain.from_iterable(
                fn(color_spaces=["RGB"], dtypes=[torch.float32])
                for fn in [
                    make_images,
                    make_vanilla_tensor_images,
                    make_videos,
                ]
            ),
        ),
    ]
)
def test_normalize(self, transform, input):
    # Only checks that the call does not crash; numerical correctness is
    # covered by dedicated tests elsewhere.
    transform(input)
@pytest.mark.parametrize(
"flat_inputs",
......@@ -385,40 +334,6 @@ def test_pure_tensor_heuristic(flat_inputs):
assert transform.was_applied(output, input)
class TestElasticTransform:
    def test_assertions(self):
        """The constructor must validate alpha, sigma and fill eagerly."""
        cases = [
            # (positional args, keyword args, expected exception, message pattern)
            (({},), {}, TypeError, "alpha should be a number or a sequence of numbers"),
            (([1.0, 2.0, 3.0],), {}, ValueError, "alpha is a sequence its length should be 1 or 2"),
            ((1.0, {}), {}, TypeError, "sigma should be a number or a sequence of numbers"),
            ((1.0, [1.0, 2.0, 3.0]), {}, ValueError, "sigma is a sequence its length should be 1 or 2"),
            ((1.0, 2.0), {"fill": "abc"}, TypeError, "Got inappropriate fill arg"),
        ]
        for args, kwargs, exc, pattern in cases:
            with pytest.raises(exc, match=pattern):
                transforms.ElasticTransform(*args, **kwargs)

    def test__get_params(self):
        """The sampled displacement field has shape (1, H, W, 2) and is bounded by alpha."""
        alpha, sigma = 2.0, 3.0
        height, width = canvas = (24, 32)
        transform = transforms.ElasticTransform(alpha, sigma)

        displacement = transform._get_params([make_image(canvas)])["displacement"]

        assert displacement.shape == (1, height, width, 2)
        # Channel 0 is the x-displacement (bounded by alpha / width),
        # channel 1 the y-displacement (bounded by alpha / height).
        dx = displacement[0, ..., 0]
        dy = displacement[0, ..., 1]
        assert (-alpha / width <= dx).all() and (dx <= alpha / width).all()
        assert (-alpha / height <= dy).all() and (dy <= alpha / height).all()
class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
......@@ -705,25 +620,6 @@ class TestRandomResize:
assert min_size <= size < max_size
class TestUniformTemporalSubsample:
    @pytest.mark.parametrize(
        "inpt",
        [
            torch.zeros(10, 3, 8, 8),
            torch.zeros(1, 10, 3, 8, 8),
            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
        ],
    )
    def test__transform(self, inpt):
        """Subsampling keeps the input's type and dtype while shrinking dim -4."""
        wanted_frames = 5

        result = transforms.UniformTemporalSubsample(wanted_frames)(inpt)

        assert type(result) is type(inpt)
        assert result.dtype == inpt.dtype
        # Dim -4 is the temporal dimension for both 4D and 5D inputs.
        assert result.shape[-4] == wanted_frames
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
......
......@@ -72,34 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Normalize,
legacy_transforms.Normalize,
[
ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
],
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
),
ConsistencyConfig(
v2_transforms.FiveCrop,
legacy_transforms.FiveCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
v2_transforms.TenCrop,
legacy_transforms.TenCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
ArgsKwargs(18, vertical_flip=True),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
*[
ConsistencyConfig(
v2_transforms.LinearTransformation,
......@@ -147,65 +119,6 @@ CONSISTENCY_CONFIGS = [
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.RandomEqualize,
legacy_transforms.RandomEqualize,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
v2_transforms.RandomInvert,
legacy_transforms.RandomInvert,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
v2_transforms.RandomPosterize,
legacy_transforms.RandomPosterize,
[
ArgsKwargs(p=0, bits=5),
ArgsKwargs(p=1, bits=1),
ArgsKwargs(p=1, bits=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
v2_transforms.RandomSolarize,
legacy_transforms.RandomSolarize,
[
ArgsKwargs(p=0, threshold=0.5),
ArgsKwargs(p=1, threshold=0.3),
ArgsKwargs(p=1, threshold=0.99),
],
),
*[
ConsistencyConfig(
v2_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
closeness_kwargs=ckw,
)
for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
],
ConsistencyConfig(
v2_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness,
[
ArgsKwargs(p=0, sharpness_factor=0.5),
ArgsKwargs(p=1, sharpness_factor=0.2),
ArgsKwargs(p=1, sharpness_factor=0.99),
],
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
),
ConsistencyConfig(
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
......@@ -230,22 +143,6 @@ CONSISTENCY_CONFIGS = [
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
ConsistencyConfig(
v2_transforms.AugMix,
legacy_transforms.AugMix,
),
ConsistencyConfig(
v2_transforms.AutoAugment,
legacy_transforms.AutoAugment,
),
ConsistencyConfig(
v2_transforms.RandAugment,
legacy_transforms.RandAugment,
),
ConsistencyConfig(
v2_transforms.TrivialAugmentWide,
legacy_transforms.TrivialAugmentWide,
),
]
......@@ -753,36 +650,9 @@ class TestRefSegTransforms:
(legacy_F.pil_to_tensor, {}),
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
(legacy_F.adjust_brightness, {}),
(legacy_F.adjust_contrast, {}),
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
(legacy_F.erase, {}),
(legacy_F.gaussian_blur, {}),
(legacy_F.invert, {}),
(legacy_F.posterize, {}),
(legacy_F.solarize, {}),
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
......
This diff is collapsed.
This diff is collapsed.
import pytest
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import InfoBase, TestMark
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
class PILKernelInfo(InfoBase):
    """Metadata wrapper around a PIL kernel used by the dispatcher tests."""

    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`; set it when the function is exposed
        # under a different name.
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
    ):
        display_name = kernel_name or kernel.__name__
        super().__init__(id=display_name)
        self.kernel = kernel
class DispatcherInfo(InfoBase):
    """Metadata for a v2 dispatcher: the kernels it routes to per tv_tensor type,
    an optional PIL kernel, and the test marks / closeness kwargs from InfoBase.
    """

    # Reverse lookup from kernel function to its KernelInfo.
    _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}

    def __init__(
        self,
        dispatcher,
        *,
        # Dictionary of types that map to the kernel the dispatcher dispatches to.
        kernels,
        # If omitted, no PIL dispatch test will be performed.
        pil_kernel_info=None,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.dispatcher = dispatcher
        self.kernels = kernels
        self.pil_kernel_info = pil_kernel_info

        # Resolve every registered kernel to its KernelInfo eagerly so a
        # missing registration fails at collection time with a clear message.
        kernel_infos = {}
        for tv_tensor_type, kernel in self.kernels.items():
            kernel_info = self._KERNEL_INFO_MAP.get(kernel)
            if not kernel_info:
                raise pytest.UsageError(
                    f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
                    f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
                )
            kernel_infos[tv_tensor_type] = kernel_info
        self.kernel_infos = kernel_infos

    def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
        """Yield sample ArgsKwargs for the given tv_tensor types (default: all registered).

        With ``filter_metadata=True``, kwargs whose names match the type's
        metadata annotations (and their ``old_*`` counterparts) are stripped,
        since dispatchers read that metadata from the tv_tensor itself.
        """
        # Hoisted out of the per-sample loop; kept function-local to avoid
        # changing the module's import surface.
        import itertools

        for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
            kernel_info = self.kernel_infos.get(tv_tensor_type)
            if not kernel_info:
                # Fixed: previously used the builtin `type` here, so the
                # message always read "type" instead of the offending class.
                raise pytest.UsageError(f"There is no kernel registered for type {tv_tensor_type.__name__}")

            sample_inputs = kernel_info.sample_inputs_fn()

            if not filter_metadata:
                yield from sample_inputs
                # Fixed: this used to `return`, which silently dropped the
                # samples of every type after the first one.
                continue

            for args_kwargs in sample_inputs:
                if hasattr(tv_tensor_type, "__annotations__"):
                    for name in itertools.chain(
                        tv_tensor_type.__annotations__.keys(),
                        # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
                        # per-dispatcher level. However, so far there is no option for that.
                        (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
                    ):
                        if name in args_kwargs.kwargs:
                            del args_kwargs.kwargs[name]

                yield args_kwargs
def xfail_jit(reason, *, condition=None):
    """Build a TestMark that xfails `TestDispatchers.test_scripted_smoke`."""
    xfail_mark = pytest.mark.xfail(reason=reason)
    return TestMark(("TestDispatchers", "test_scripted_smoke"), xfail_mark, condition=condition)
def xfail_jit_python_scalar_arg(name, *, reason=None):
    """xfail scripting for samples that pass `name` as a plain int or float."""
    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"

    def passes_python_scalar(args_kwargs):
        return isinstance(args_kwargs.kwargs.get(name), (int, float))

    return xfail_jit(reason, condition=passes_python_scalar)
# Mark for dispatchers that cannot take part in the generic tv_tensor
# dispatch test.
skip_dispatch_tv_tensor = TestMark(
    ("TestDispatchers", "test_dispatch_tv_tensor"),
    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
)

# five_crop / ten_crop return a sequence of crops, so every single-output
# type check (and the tv_tensor dispatch test) must be skipped for them.
multi_crop_skips = [
    TestMark(
        ("TestDispatchers", name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
    for name in ("test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type")
]
multi_crop_skips += [skip_dispatch_tv_tensor]
# Registry of v2 dispatchers under test. Each entry maps a dispatcher to its
# per-tv_tensor kernels, an optional PIL kernel, and the test marks needed to
# skip/xfail known limitations.
DISPATCHER_INFOS = [
    # Geometry: elastic dispatches to kernels for all four tv_tensor types.
    DispatcherInfo(
        F.elastic,
        kernels={
            tv_tensors.Image: F.elastic_image,
            tv_tensors.Video: F.elastic_video,
            tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
            tv_tensors.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
        test_marks=[xfail_jit_python_scalar_arg("fill")],
    ),
    # Color / photometric dispatchers: image and video kernels only.
    DispatcherInfo(
        F.equalize,
        kernels={
            tv_tensors.Image: F.equalize_image,
            tv_tensors.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
            tv_tensors.Image: F.invert_image,
            tv_tensors.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
            tv_tensors.Image: F.posterize_image,
            tv_tensors.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
            tv_tensors.Image: F.solarize_image,
            tv_tensors.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
            tv_tensors.Image: F.autocontrast_image,
            tv_tensors.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
            tv_tensors.Image: F.adjust_sharpness_image,
            tv_tensors.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
            tv_tensors.Image: F.adjust_contrast_image,
            tv_tensors.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
            tv_tensors.Image: F.adjust_gamma_image,
            tv_tensors.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
            tv_tensors.Image: F.adjust_hue_image,
            tv_tensors.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
            tv_tensors.Image: F.adjust_saturation_image,
            tv_tensors.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    # Multi-crop dispatchers: return a sequence of crops, hence the extra skips.
    DispatcherInfo(
        F.five_crop,
        kernels={
            tv_tensors.Image: F.five_crop_image,
            tv_tensors.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
    ),
    DispatcherInfo(
        F.ten_crop,
        kernels={
            tv_tensors.Image: F.ten_crop_image,
            tv_tensors.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
            *multi_crop_skips,
        ],
        pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
    ),
    # No PIL kernel: normalize only supports float tensor inputs.
    DispatcherInfo(
        F.normalize,
        kernels={
            tv_tensors.Image: F.normalize_image,
            tv_tensors.Video: F.normalize_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("mean"),
            xfail_jit_python_scalar_arg("std"),
        ],
    ),
    # Video-only dispatcher; arbitrary tv_tensor dispatch does not apply.
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            tv_tensors.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
    # BoundingBoxes-only dispatcher; the dispatcher doubles as its own kernel.
    DispatcherInfo(
        F.clamp_bounding_boxes,
        kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
        test_marks=[
            skip_dispatch_tv_tensor,
        ],
    ),
]
This diff is collapsed.
......@@ -5,11 +5,9 @@ implemented there and must not use any of the utilities here.
The following legacy modules depend on this module
- transforms_v2_kernel_infos.py
- transforms_v2_dispatcher_infos.py
- test_transforms_v2_functional.py
- test_transforms_v2_consistency.py
- test_transforms.py
- test_transforms_v2.py
When all the logic is ported from the files above to test_transforms_v2_refactored.py, delete
all the legacy modules including this one and drop the _refactored prefix from the name.
......
......@@ -328,6 +328,11 @@ class RandomSolarize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomSolarize
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    # Coerce the stored threshold to a plain Python float before handing the
    # params to the v1 transform — presumably the v1 RandomSolarize rejects
    # non-float thresholds; TODO confirm against torchvision.transforms.
    params = super()._extract_params_for_v1_transform()
    params["threshold"] = float(params["threshold"])
    return params
def __init__(self, threshold: float, p: float = 0.5) -> None:
    # `p` (apply-probability) is handled by the _RandomApplyTransform base;
    # only the solarize threshold is stored here.
    super().__init__(p=p)
    self.threshold = threshold
......
......@@ -261,7 +261,7 @@ def clamp_bounding_boxes(
if torch.jit.is_scripting() or is_pure_tensor(inpt):
if format is None or canvas_size is None:
raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, tv_tensors.BoundingBoxes):
if format is not None or canvas_size is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment