import io

import PIL.Image

import torch
from torchvision import tv_tensors
from torchvision.io import decode_jpeg, encode_jpeg
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal


def erase(
    inpt: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomErasing` for details.

    Looks up the ``erase`` kernel registered for ``type(inpt)`` and applies it.
    The tensor kernel in this module writes ``v`` into the window
    ``[..., i : i + h, j : j + w]`` of the input.

    Args:
        inpt: Input to erase; its type selects the registered kernel.
        i: First row index of the erased window.
        j: First column index of the erased window.
        h: Height (number of rows) of the erased window.
        w: Width (number of columns) of the erased window.
        v: Value(s) assigned to the erased window.
        inplace: Forwarded to the kernel; the tensor kernel in this module
            writes into the input itself instead of a clone when ``True``.

    Returns:
        Whatever the selected kernel returns (the erased input).
    """
    # TorchScript cannot use the type-based kernel registry, so scripted code
    # always goes straight to the plain-tensor kernel.
    if torch.jit.is_scripting():
        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

    _log_api_usage_once(erase)

    kernel = _get_kernel(erase, type(inpt))
    return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, tv_tensors.Image)
def erase_image(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Tensor kernel for ``erase``.

    Assigns ``v`` to the ``h`` x ``w`` window whose top-left corner is
    ``(i, j)`` in the last two dimensions of ``image``; all leading dimensions
    are covered by the ``...`` index.

    Args:
        image: Tensor to erase.
        i, j: Row and column of the window's top-left corner.
        h, w: Window height and width.
        v: Value(s) written into the window.
        inplace: When ``True``, ``image`` itself is modified; otherwise a clone is.

    Returns:
        The modified tensor (``image`` itself when ``inplace`` is ``True``).
    """
    target = image if inplace else image.clone()
    target[..., i : i + h, j : j + w] = v
    return target


@_register_kernel_internal(erase, PIL.Image.Image)
def _erase_image_pil(
    image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
    """PIL kernel for ``erase``.

    Converts ``image`` to a tensor, erases it with :func:`erase_image`, and
    converts the result back to a PIL image with the original ``image.mode``.
    """
    erased = erase_image(pil_to_tensor(image), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return to_pil_image(erased, mode=image.mode)


@_register_kernel_internal(erase, tv_tensors.Video)
def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Video kernel for ``erase``.

    Delegates to :func:`erase_image`, whose ``[..., rows, cols]`` indexing
    applies the same window to every frame.
    """
    return erase_image(video, i, j, h, w, v, inplace)


def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details.

    Looks up the ``jpeg`` kernel registered for ``type(image)`` and applies it
    with the given ``quality``. Under TorchScript, :func:`jpeg_image` is called
    directly instead.
    """
    # The type-based kernel registry is unavailable to TorchScript.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)
    jpeg_kernel = _get_kernel(jpeg, type(image))
    return jpeg_kernel(image, quality=quality)


@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    """Tensor kernel for ``jpeg``: JPEG-encode and decode every image in a batch.

    The last three dimensions of ``image`` are treated as one image. All leading
    dimensions are flattened into a single batch dimension. Each image is passed
    through ``encode_jpeg`` / ``decode_jpeg``, and the stacked result is reshaped
    back to the input's original shape.

    Args:
        image: Tensor whose last three dimensions form one image.
        quality: JPEG quality forwarded to ``encode_jpeg``.

    Returns:
        Tensor of the decoded images, reshaped to ``image``'s original shape.
    """
    original_shape = image.shape
    # reshape rather than view: view raises when the leading dimensions cannot be
    # merged without a copy (e.g. a video batch sliced or transposed along them).
    # reshape still returns a view whenever one is possible.
    image = image.reshape((-1,) + image.shape[-3:])

    if image.shape[0] == 0:  # degenerate: nothing to encode
        return image.reshape(original_shape).clone()

    decoded_images = []
    for idx in range(image.shape[0]):
        # isinstance asserts narrow the return types for TorchScript.
        encoded_image = encode_jpeg(image[idx], quality=quality)
        assert isinstance(encoded_image, torch.Tensor)
        decoded_image = decode_jpeg(encoded_image)
        assert isinstance(decoded_image, torch.Tensor)
        decoded_images.append(decoded_image)

    # torch.stack produces a contiguous tensor, so view is always valid here.
    return torch.stack(decoded_images, dim=0).view(original_shape)


@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    """Video kernel for ``jpeg``.

    Delegates to :func:`jpeg_image`, which flattens all leading dimensions into
    one batch of images.
    """
    return jpeg_image(video, quality)


@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    """PIL kernel for ``jpeg``: save ``image`` as JPEG into memory and load it back."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # PIL.Image.open() yields a PIL.JpegImagePlugin.JpegImageFile, a subclass of
    # PIL.Image.Image. The check_transform() test rejects that subclass, so the
    # reloaded image is copied before returning.
    reloaded = PIL.Image.open(buffer)
    return reloaded.copy()