Unverified Commit 59dc9383 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

perform out of bounds check for single values and two tuples in ColorJitter (#7133)

parent d5091562
......@@ -317,7 +317,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
......
......@@ -1798,6 +1798,12 @@ def test_color_jitter():
color_jitter.__repr__()
@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
transforms.ColorJitter(hue=hue)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed):
......
......@@ -77,12 +77,12 @@ class ColorJitter(Transform):
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
else:
elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
raise TypeError(f"{name} should be a single number or a sequence with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
@staticmethod
......
......@@ -1195,16 +1195,19 @@ class ColorJitter(torch.nn.Module):
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}")
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
return None
else:
return tuple(value)
@staticmethod
def get_params(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment