Unverified Commit 924b1626 authored by Thien Tran's avatar Thien Tran Committed by GitHub
Browse files

Add JPEG augmentation (#8316)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 2ba586d5
...@@ -407,6 +407,7 @@ Miscellaneous ...@@ -407,6 +407,7 @@ Miscellaneous
v2.SanitizeBoundingBoxes v2.SanitizeBoundingBoxes
v2.ClampBoundingBoxes v2.ClampBoundingBoxes
v2.UniformTemporalSubsample v2.UniformTemporalSubsample
v2.JPEG
Functionals Functionals
...@@ -419,6 +420,7 @@ Functionals ...@@ -419,6 +420,7 @@ Functionals
v2.functional.sanitize_bounding_boxes v2.functional.sanitize_bounding_boxes
v2.functional.clamp_bounding_boxes v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample v2.functional.uniform_temporal_subsample
v2.functional.jpeg
.. _conversion_transforms: .. _conversion_transforms:
......
...@@ -237,6 +237,17 @@ equalizer = v2.RandomEqualize() ...@@ -237,6 +237,17 @@ equalizer = v2.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)] equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot([orig_img] + equalized_imgs) plot([orig_img] + equalized_imgs)
# %%
# JPEG
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.v2.JPEG` transform
# (see also :func:`~torchvision.transforms.v2.functional.jpeg`)
# applies JPEG compression to the given image with a randomly
# chosen degree of compression.
# Quality is sampled uniformly from [5, 50]; lower quality means
# stronger compression artifacts.
jpeg = v2.JPEG((5, 50))
jpeg_imgs = [jpeg(orig_img) for _ in range(4)]
plot([orig_img] + jpeg_imgs)
# %% # %%
# Augmentation Transforms # Augmentation Transforms
# ----------------------- # -----------------------
......
...@@ -5932,3 +5932,86 @@ class TestSanitizeBoundingBoxes: ...@@ -5932,3 +5932,86 @@ class TestSanitizeBoundingBoxes:
with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"): with pytest.raises(ValueError, match="bouding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
F.sanitize_bounding_boxes(good_bbox.tolist()) F.sanitize_bounding_boxes(good_bbox.tolist())
class TestJPEG:
    # Tests for the v2 JPEG augmentation: the F.jpeg dispatcher, the
    # per-type kernels (pure tensor / Image / Video / PIL), the JPEG
    # Transform class, and its parameter validation.

    @pytest.mark.parametrize("quality", [5, 75])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_kernel_image(self, quality, color_space):
        # Kernel smoke test over low/moderate quality and 3-/1-channel input.
        check_kernel(F.jpeg_image, make_image(color_space=color_space), quality=quality)

    def test_kernel_video(self):
        check_kernel(F.jpeg_video, make_video(), quality=5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        # The dispatcher must accept every supported input type.
        check_functional(F.jpeg, make_input(), quality=5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.jpeg_image, torch.Tensor),
            (F._jpeg_image_pil, PIL.Image.Image),
            (F.jpeg_image, tv_tensors.Image),
            (F.jpeg_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        # Each registered kernel must match the dispatcher's signature.
        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_transform(self, make_input, quality, color_space):
        # Covers both the fixed-quality (int) and ranged (tuple) constructor forms.
        check_transform(transforms.JPEG(quality=quality), make_input(color_space=color_space))

    @pytest.mark.parametrize("quality", [5])
    def test_functional_image_correctness(self, quality):
        image = make_image()

        actual = F.jpeg(image, quality=quality)
        expected = F.to_image(F.jpeg(F.to_pil_image(image), quality=quality))

        # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, quality, color_space, seed):
        image = make_image(color_space=color_space)

        transform = transforms.JPEG(quality=quality)

        # Re-seed before each call so both the tensor and the PIL path draw the
        # same random quality from _get_params.
        with freeze_rng_state():
            torch.manual_seed(seed)
            actual = transform(image)

            torch.manual_seed(seed)
            expected = F.to_image(transform(F.to_pil_image(image)))

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, quality, seed):
        # Sampled quality must equal a scalar quality exactly, or fall inside an
        # inclusive (min, max) range.
        transform = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            params = transform._get_params([])

        if isinstance(quality, int):
            assert params["quality"] == quality
        else:
            assert quality[0] <= params["quality"] <= quality[1]

    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
    def test_transform_sequence_len_error(self, quality):
        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
            transforms.JPEG(quality=quality)

    @pytest.mark.parametrize("quality", [-1, 0, 150])
    def test_transform_invalid_quality_error(self, quality):
        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
            transforms.JPEG(quality=quality)
...@@ -4,7 +4,7 @@ from . import functional # usort: skip ...@@ -4,7 +4,7 @@ from . import functional # usort: skip
from ._transform import Transform # usort: skip from ._transform import Transform # usort: skip
from ._augment import CutMix, MixUp, RandomErasing from ._augment import CutMix, JPEG, MixUp, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import ( from ._color import (
ColorJitter, ColorJitter,
......
import math import math
import numbers import numbers
import warnings import warnings
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -11,7 +11,7 @@ from torchvision import transforms as _transforms, tv_tensors ...@@ -11,7 +11,7 @@ from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -317,3 +317,39 @@ class CutMix(_BaseMixUpCutMix): ...@@ -317,3 +317,39 @@ class CutMix(_BaseMixUpCutMix):
return output return output
else: else:
return inpt return inpt
class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        # A scalar quality is normalized to a degenerate [q, q] range so that
        # _get_params can treat both cases uniformly.
        if isinstance(quality, int):
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        # Both bounds must be true ints (not floats) and lie in 1..100, with
        # min <= max. The range check runs first, matching the original order
        # of evaluation.
        in_range = 1 <= quality[0] <= quality[1] <= 100
        both_int = isinstance(quality[0], int) and isinstance(quality[1], int)
        if not (in_range and both_int):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Sample uniformly from [min, max]; torch.randint's upper bound is
        # exclusive, hence the +1.
        low, high = self.quality
        return {"quality": int(torch.randint(low, high + 1, ()).item())}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
...@@ -24,7 +24,7 @@ from ._meta import ( ...@@ -24,7 +24,7 @@ from ._meta import (
get_size, get_size,
) # usort: skip ) # usort: skip
from ._augment import _erase_image_pil, erase, erase_image, erase_video from ._augment import _erase_image_pil, _jpeg_image_pil, erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video
from ._color import ( from ._color import (
_adjust_brightness_image_pil, _adjust_brightness_image_pil,
_adjust_contrast_image_pil, _adjust_contrast_image_pil,
......
import io
import PIL.Image import PIL.Image
import torch import torch
from torchvision import tv_tensors from torchvision import tv_tensors
from torchvision.io import decode_jpeg, encode_jpeg
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -53,3 +56,43 @@ def erase_video( ...@@ -53,3 +56,43 @@ def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details."""
    # TorchScript cannot follow the dynamic dispatch below, so scripted code
    # goes straight to the pure-tensor kernel.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    # Look up the kernel registered for this input type and delegate to it.
    dispatched = _get_kernel(jpeg, type(image))
    return dispatched(image, quality=quality)
@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    """JPEG-compress then decompress ``image``.

    Args:
        image: uint8 tensor of shape [..., 1 or 3, H, W]; any leading
            dimensions are flattened into a batch and restored afterwards.
        quality: JPEG quality from 1 to 100; lower means more compression.

    Returns:
        A tensor with the same shape as ``image`` containing the
        round-tripped pixels.
    """
    original_shape = image.shape
    # reshape (not view) so non-contiguous inputs — e.g. a permuted
    # channels-last tensor — are flattened instead of raising.
    batch = image.reshape((-1,) + image.shape[-3:])

    if batch.shape[0] == 0:  # degenerate: no images to round-trip
        return batch.reshape(original_shape).clone()

    # encode_jpeg/decode_jpeg operate on a single [C, H, W] image, so each
    # batch element is round-tripped individually.
    compressed = [decode_jpeg(encode_jpeg(batch[i], quality=quality)) for i in range(batch.shape[0])]
    return torch.stack(compressed, dim=0).reshape(original_shape)
@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    # Videos add a leading frame dimension; jpeg_image already flattens all
    # leading dimensions into a batch, so plain delegation suffices.
    return jpeg_image(video, quality=quality)
@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    """Round-trip a PIL image through an in-memory JPEG encode/decode."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # we need to copy since PIL.Image.open() will return PIL.JpegImagePlugin.JpegImageFile
    # which is a sub-class of PIL.Image.Image. this will fail check_transform() test.
    decoded = PIL.Image.open(buffer)
    return decoded.copy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment