Unverified commit f483e71b, authored by Loi Ly, committed by GitHub
Browse files

Added gray image support to `adjust_saturation` function (#4480)

* pass the channels parameter to every call to check_functional_vs_PIL_vs_scripted

* update adjust_saturation

* update docstrings for functional transformations

* parametrize channels

* update docstring of ColorJitter class

* move channels to class's parameter

* remove testing channels for geometric transforms

* revert redundant changes

* revert redundant changes

* update grayscale test cases for randaugment, autoaugment, trivialaugment

* update docstrings of randaugment, autoaugment, trivialaugment

* update docstring of ColorJitter

* fix adjust_hue's docstring

* change test equal tolerance

* refactor grayscale tests

* make get_grayscale_test_image private
parent 3e27eb21
......@@ -148,7 +148,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
return batch_tensor
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=1e-6)
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
......
......@@ -681,7 +681,8 @@ def check_functional_vs_PIL_vs_scripted(
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
def test_adjust_brightness(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_brightness(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_brightness,
F_pil.adjust_brightness,
......@@ -689,12 +690,14 @@ def test_adjust_brightness(device, dtype, config):
config,
device,
dtype,
channels,
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_invert(device, dtype):
@pytest.mark.parametrize('channels', [1, 3])
def test_invert(device, dtype, channels):
check_functional_vs_PIL_vs_scripted(
F.invert,
F_pil.invert,
......@@ -702,6 +705,7 @@ def test_invert(device, dtype):
{},
device,
dtype,
channels,
tol=1.0,
agg_method="max"
)
......@@ -709,7 +713,8 @@ def test_invert(device, dtype):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)])
def test_posterize(device, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_posterize(device, config, channels):
check_functional_vs_PIL_vs_scripted(
F.posterize,
F_pil.posterize,
......@@ -717,6 +722,7 @@ def test_posterize(device, config):
config,
device,
dtype=None,
channels=channels,
tol=1.0,
agg_method="max",
)
......@@ -724,7 +730,8 @@ def test_posterize(device, config):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
def test_solarize1(device, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_solarize1(device, config, channels):
check_functional_vs_PIL_vs_scripted(
F.solarize,
F_pil.solarize,
......@@ -732,6 +739,7 @@ def test_solarize1(device, config):
config,
device,
dtype=None,
channels=channels,
tol=1.0,
agg_method="max",
)
......@@ -740,7 +748,8 @@ def test_solarize1(device, config):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
def test_solarize2(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_solarize2(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.solarize,
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
......@@ -748,6 +757,7 @@ def test_solarize2(device, dtype, config):
config,
device,
dtype,
channels,
tol=1.0,
agg_method="max",
)
......@@ -756,7 +766,8 @@ def test_solarize2(device, dtype, config):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_sharpness(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_sharpness(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_sharpness,
F_pil.adjust_sharpness,
......@@ -764,12 +775,14 @@ def test_adjust_sharpness(device, dtype, config):
config,
device,
dtype,
channels,
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_autocontrast(device, dtype):
@pytest.mark.parametrize('channels', [1, 3])
def test_autocontrast(device, dtype, channels):
check_functional_vs_PIL_vs_scripted(
F.autocontrast,
F_pil.autocontrast,
......@@ -777,13 +790,15 @@ def test_autocontrast(device, dtype):
{},
device,
dtype,
channels,
tol=1.0,
agg_method="max"
)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_equalize(device):
@pytest.mark.parametrize('channels', [1, 3])
def test_equalize(device, channels):
torch.use_deterministic_algorithms(False)
check_functional_vs_PIL_vs_scripted(
F.equalize,
......@@ -792,6 +807,7 @@ def test_equalize(device):
{},
device,
dtype=None,
channels=channels,
tol=1.0,
agg_method="max",
)
......@@ -809,28 +825,31 @@ def test_adjust_contrast(device, dtype, config, channels):
config,
device,
dtype,
channels=channels
channels
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
def test_adjust_saturation(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_saturation(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_saturation,
F_pil.adjust_saturation,
F_t.adjust_saturation,
config,
device,
dtype
dtype,
channels
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
def test_adjust_hue(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_hue(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_hue,
F_pil.adjust_hue,
......@@ -838,6 +857,7 @@ def test_adjust_hue(device, dtype, config):
config,
device,
dtype,
channels,
tol=16.1,
agg_method="max"
)
......@@ -846,7 +866,8 @@ def test_adjust_hue(device, dtype, config):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
def test_adjust_gamma(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_gamma(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_gamma,
F_pil.adjust_gamma,
......@@ -854,6 +875,7 @@ def test_adjust_gamma(device, dtype, config):
config,
device,
dtype,
channels,
)
......
......@@ -26,6 +26,12 @@ GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
def _get_grayscale_test_image(img, fill=None):
img = img.convert('L')
fill = (fill[0], ) if isinstance(fill, tuple) else fill
return img, fill
class TestConvertImageDtype:
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes()))
def test_float_to_float(self, input_dtype, output_dtype):
......@@ -1482,9 +1488,12 @@ def test_five_crop(single_dim):
@pytest.mark.parametrize('policy', transforms.AutoAugmentPolicy)
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
def test_autoaugment(policy, fill):
@pytest.mark.parametrize('grayscale', [True, False])
def test_autoaugment(policy, fill, grayscale):
random.seed(42)
img = Image.open(GRACE_HOPPER)
if grayscale:
img, fill = _get_grayscale_test_image(img, fill)
transform = transforms.AutoAugment(policy=policy, fill=fill)
for _ in range(100):
img = transform(img)
......@@ -1494,9 +1503,12 @@ def test_autoaugment(policy, fill):
@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
def test_randaugment(num_ops, magnitude, fill):
@pytest.mark.parametrize('grayscale', [True, False])
def test_randaugment(num_ops, magnitude, fill, grayscale):
random.seed(42)
img = Image.open(GRACE_HOPPER)
if grayscale:
img, fill = _get_grayscale_test_image(img, fill)
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
for _ in range(100):
img = transform(img)
......@@ -1505,9 +1517,12 @@ def test_randaugment(num_ops, magnitude, fill):
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
@pytest.mark.parametrize('grayscale', [True, False])
def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
random.seed(42)
img = Image.open(GRACE_HOPPER)
if grayscale:
img, fill = _get_grayscale_test_image(img, fill)
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
for _ in range(100):
img = transform(img)
......
......@@ -47,10 +47,10 @@ def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors,
assert_equal(transformed_batch, s_transformed_batch, msg=msg)
def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **match_kwargs):
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
fn_kwargs = fn_kwargs or {}
tensor, pil_img = _create_data(height=10, width=10, device=device)
tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
......@@ -59,7 +59,7 @@ def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **matc
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **match_kwargs):
def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
# TODO: change the name: it's not a method, it's a class.
meth_kwargs = meth_kwargs or {}
......@@ -67,7 +67,7 @@ def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **ma
f = method(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = _create_data(26, 34, device=device)
tensor, pil_img = _create_data(26, 34, channels, device=device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
......@@ -82,16 +82,16 @@ def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **ma
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
_test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
_test_functional_op(func, device, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
_test_class_op(method, device, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
_test_functional_op(func, device, channels, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
_test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -109,54 +109,56 @@ def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_
(F.equalize, T.RandomEqualize, None, {})
]
)
def test_random(func, method, device, fn_kwargs, match_kwargs):
_test_op(func, method, device, fn_kwargs, fn_kwargs, **match_kwargs)
@pytest.mark.parametrize('channels', [1, 3])
def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
_test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('channels', [1, 3])
class TestColorJitter:
@pytest.mark.parametrize('brightness', [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
def test_color_jitter_brightness(self, brightness, device):
def test_color_jitter_brightness(self, brightness, device, channels):
tol = 1.0 + 1e-10
meth_kwargs = {"brightness": brightness}
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=tol, agg_method="max"
tol=tol, agg_method="max", channels=channels,
)
@pytest.mark.parametrize('contrast', [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
def test_color_jitter_contrast(self, contrast, device):
def test_color_jitter_contrast(self, contrast, device, channels):
tol = 1.0 + 1e-10
meth_kwargs = {"contrast": contrast}
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=tol, agg_method="max"
tol=tol, agg_method="max", channels=channels
)
@pytest.mark.parametrize('saturation', [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
def test_color_jitter_saturation(self, saturation, device):
def test_color_jitter_saturation(self, saturation, device, channels):
tol = 1.0 + 1e-10
meth_kwargs = {"saturation": saturation}
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=tol, agg_method="max"
tol=tol, agg_method="max", channels=channels
)
@pytest.mark.parametrize('hue', [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
def test_color_jitter_hue(self, hue, device):
def test_color_jitter_hue(self, hue, device, channels):
meth_kwargs = {"hue": hue}
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=16.1, agg_method="max"
tol=16.1, agg_method="max", channels=channels
)
def test_color_jitter_all(self, device):
def test_color_jitter_all(self, device, channels):
# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=12.1, agg_method="max"
tol=12.1, agg_method="max", channels=channels
)
......@@ -226,7 +228,7 @@ def test_crop(device):
def test_crop_pad(size, padding_config, device):
config = dict(padding_config)
config["size"] = size
_test_class_op(T.RandomCrop, device, config)
_test_class_op(T.RandomCrop, device, meth_kwargs=config)
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -709,9 +711,10 @@ def test_random_apply(device):
{"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
{"kernel_size": [23], "sigma": 0.75}
])
def test_gaussian_blur(device, meth_kwargs):
@pytest.mark.parametrize('channels', [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs):
tol = 1.0 + 1e-10
_test_class_op(
T.GaussianBlur, meth_kwargs=meth_kwargs,
T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels,
test_exact_match=False, device=device, agg_method="max", tol=tol
)
......@@ -65,8 +65,8 @@ class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "RGB".
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
......@@ -249,8 +249,8 @@ class RandAugment(torch.nn.Module):
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "RGB".
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_ops (int): Number of augmentation transformations to apply sequentially.
......@@ -333,8 +333,8 @@ class TrivialAugmentWide(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "RGB".
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int): The number of different magnitude values.
......
......@@ -791,7 +791,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
......@@ -811,7 +811,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
......@@ -842,9 +842,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
......
......@@ -215,7 +215,10 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
_assert_image_tensor(img)
_assert_channels(img, [3])
_assert_channels(img, [1, 3])
if get_image_num_channels(img) == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
......
......@@ -1098,8 +1098,8 @@ class LinearTransformation(torch.nn.Module):
class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation and hue of an image.
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment