Unverified Commit 59dc9383 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

perform out of bounds check for single values and two tuples in ColorJitter (#7133)

parent d5091562
...@@ -317,7 +317,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -317,7 +317,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(saturation=(0.8, 0.9)), ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3), ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)), ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6), ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
], ],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5}, closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
), ),
......
...@@ -1798,6 +1798,12 @@ def test_color_jitter(): ...@@ -1798,6 +1798,12 @@ def test_color_jitter():
color_jitter.__repr__() color_jitter.__repr__()
@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
transforms.ColorJitter(hue=hue)
@pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed): def test_random_erasing(seed):
......
...@@ -77,12 +77,12 @@ class ColorJitter(Transform): ...@@ -77,12 +77,12 @@ class ColorJitter(Transform):
value = [center - value, center + value] value = [center - value, center + value]
if clip_first_on_zero: if clip_first_on_zero:
value[0] = max(value[0], 0.0) value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2: elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
else:
raise TypeError(f"{name} should be a single number or a sequence with length 2.") raise TypeError(f"{name} should be a single number or a sequence with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
@staticmethod @staticmethod
......
...@@ -1195,16 +1195,19 @@ class ColorJitter(torch.nn.Module): ...@@ -1195,16 +1195,19 @@ class ColorJitter(torch.nn.Module):
if clip_first_on_zero: if clip_first_on_zero:
value[0] = max(value[0], 0.0) value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2: elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]: value = [float(value[0]), float(value[1])]
raise ValueError(f"{name} values should be between {bound}")
else: else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
# if value is 0 or (1., 1.) for brightness/contrast/saturation # if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing # or (0., 0.) for hue, do nothing
if value[0] == value[1] == center: if value[0] == value[1] == center:
value = None return None
return value else:
return tuple(value)
@staticmethod @staticmethod
def get_params( def get_params(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment