Unverified Commit f1b4c7a6 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed sigma input type for v2.GaussianBlur (#7887)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent a2f8f8e9
...@@ -449,37 +449,6 @@ class TestRandomZoomOut: ...@@ -449,37 +449,6 @@ class TestRandomZoomOut:
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
class TestGaussianBlur:
def test_assertions(self):
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
transforms.GaussianBlur([10, 12, 14])
with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
transforms.GaussianBlur(4)
with pytest.raises(
TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats."
):
transforms.GaussianBlur(3, sigma=[1, 2, 3])
with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
transforms.GaussianBlur(3, sigma=-1.0)
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=[2.0, 1.0])
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
def test__get_params(self, sigma):
transform = transforms.GaussianBlur(3, sigma=sigma)
params = transform._get_params([])
if isinstance(sigma, float):
assert params["sigma"][0] == params["sigma"][1] == 10
else:
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]
class TestRandomPerspective: class TestRandomPerspective:
def test_assertions(self): def test_assertions(self):
with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"): with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
...@@ -503,24 +472,18 @@ class TestRandomPerspective: ...@@ -503,24 +472,18 @@ class TestRandomPerspective:
class TestElasticTransform: class TestElasticTransform:
def test_assertions(self): def test_assertions(self):
with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"):
transforms.ElasticTransform({}) transforms.ElasticTransform({})
with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"):
transforms.ElasticTransform([1.0, 2.0, 3.0]) transforms.ElasticTransform([1.0, 2.0, 3.0])
with pytest.raises(ValueError, match="alpha should be a sequence of floats"): with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
transforms.ElasticTransform([1, 2])
with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
transforms.ElasticTransform(1.0, {}) transforms.ElasticTransform(1.0, {})
with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"):
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])
with pytest.raises(ValueError, match="sigma should be a sequence of floats"):
transforms.ElasticTransform(1.0, [1, 2])
with pytest.raises(TypeError, match="Got inappropriate fill arg"): with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.ElasticTransform(1.0, 2.0, fill="abc") transforms.ElasticTransform(1.0, 2.0, fill="abc")
......
...@@ -2859,3 +2859,46 @@ class TestErase: ...@@ -2859,3 +2859,46 @@ class TestErase:
_, output = transform(make_image(self.INPUT_SIZE), input) _, output = transform(make_image(self.INPUT_SIZE), input)
assert output is input assert output is input
class TestGaussianBlur:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("sigma", [5, (0.5, 2)])
def test_transform(self, make_input, device, sigma):
check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device))
def test_assertions(self):
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
transforms.GaussianBlur([10, 12, 14])
with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
transforms.GaussianBlur(4)
with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"):
transforms.GaussianBlur(3, sigma=[1, 2, 3])
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=-1.0)
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=[2.0, 1.0])
with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
transforms.GaussianBlur(3, sigma={})
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
def test__get_params(self, sigma):
transform = transforms.GaussianBlur(3, sigma=sigma)
params = transform._get_params([])
if isinstance(sigma, float):
assert params["sigma"][0] == params["sigma"][1] == sigma
elif isinstance(sigma, list) and len(sigma) == 1:
assert params["sigma"][0] == params["sigma"][1] == sigma[0]
else:
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]
...@@ -21,7 +21,7 @@ from ._utils import ( ...@@ -21,7 +21,7 @@ from ._utils import (
_get_fill, _get_fill,
_setup_angle, _setup_angle,
_setup_fill_arg, _setup_fill_arg,
_setup_float_or_seq, _setup_number_or_seq,
_setup_size, _setup_size,
get_bounding_boxes, get_bounding_boxes,
has_all, has_all,
...@@ -1060,8 +1060,8 @@ class ElasticTransform(Transform): ...@@ -1060,8 +1060,8 @@ class ElasticTransform(Transform):
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.alpha = _setup_number_or_seq(alpha, "alpha")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.sigma = _setup_number_or_seq(sigma, "sigma")
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.fill = fill self.fill = fill
......
...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten ...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
# TODO: do we want/need to expose this? # TODO: do we want/need to expose this?
...@@ -198,17 +198,10 @@ class GaussianBlur(Transform): ...@@ -198,17 +198,10 @@ class GaussianBlur(Transform):
if ks <= 0 or ks % 2 == 0: if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.") raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, (int, float)): self.sigma = _setup_number_or_seq(sigma, "sigma")
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2) if not 0.0 < self.sigma[0] <= self.sigma[1]:
raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
......
...@@ -18,20 +18,23 @@ from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pu ...@@ -18,20 +18,23 @@ from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pu
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)): if not isinstance(arg, (int, float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size: if isinstance(arg, Sequence) and len(arg) not in (1, 2):
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")
if isinstance(arg, Sequence): if isinstance(arg, Sequence):
for element in arg: for element in arg:
if not isinstance(element, float): if not isinstance(element, (int, float)):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}")
if isinstance(arg, float): if isinstance(arg, (int, float)):
arg = [float(arg), float(arg)] arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1: elif isinstance(arg, Sequence):
arg = [arg[0], arg[0]] if len(arg) == 1:
arg = [float(arg[0]), float(arg[0])]
else:
arg = [float(arg[0]), float(arg[1])]
return arg return arg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment