_augment.py 3.12 KB
Newer Older
Thien Tran's avatar
Thien Tran committed
1
2
import io

3
4
5
import PIL.Image

import torch
6
from torchvision import tv_tensors
Thien Tran's avatar
Thien Tran committed
7
from torchvision.io import decode_jpeg, encode_jpeg
8
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
9
from torchvision.utils import _log_api_usage_once
10

11
from ._utils import _get_kernel, _register_kernel_internal
12

13

14
def erase(
    inpt: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomErasing` for details.

    Dispatches to the ``erase`` kernel registered for ``type(inpt)``.

    Args:
        inpt: Input from which a rectangular region is erased.
        i: Start index of the region along the second-to-last dimension.
        j: Start index of the region along the last dimension.
        h: Extent of the region along the second-to-last dimension.
        w: Extent of the region along the last dimension.
        v: Value(s) assigned to the region.
        inplace: Forwarded unchanged to the selected kernel. The tensor
            kernel clones its input first when this is ``False``.

    Returns:
        The result of the selected kernel.
    """
    # No dynamic dispatch while scripting: call the plain tensor kernel directly.
    if torch.jit.is_scripting():
        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

    _log_api_usage_once(erase)

    kernel = _get_kernel(erase, type(inpt))
    return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
31
32


33
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, tv_tensors.Image)
def erase_image(
    image: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """Tensor kernel for :func:`erase`.

    Assigns ``v`` to the slice ``[..., i:i + h, j:j + w]`` of ``image``.

    Args:
        image: Tensor to erase from; only its last two dimensions are sliced.
        i: Start index along the second-to-last dimension.
        j: Start index along the last dimension.
        h: Slice length along the second-to-last dimension.
        w: Slice length along the last dimension.
        v: Value(s) assigned to the slice.
        inplace: If ``True``, ``image`` itself is modified and returned;
            otherwise a clone is modified and returned.

    Returns:
        The tensor with the region overwritten.
    """
    # Only touch the caller's tensor when in-place mutation was requested.
    target = image if inplace else image.clone()
    target[..., i : i + h, j : j + w] = v
    return target
43
44


45
@_register_kernel_internal(erase, PIL.Image.Image)
def _erase_image_pil(
    image: PIL.Image.Image,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> PIL.Image.Image:
    """PIL kernel for :func:`erase`.

    Converts ``image`` to a tensor, applies ``erase_image`` with the given
    region and value, and converts the result back to a PIL image.

    Returns:
        A PIL image built from the erased tensor, using ``image.mode``.
    """
    erased = erase_image(pil_to_tensor(image), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    # Convert back using the mode of the source image.
    return to_pil_image(erased, mode=image.mode)
52
53


54
@_register_kernel_internal(erase, tv_tensors.Video)
def erase_video(
    video: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """Video kernel for :func:`erase`.

    Delegates to ``erase_image``, which slices only the last two dimensions
    and therefore applies the same region to every leading index.
    """
    erased = erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return erased
Thien Tran's avatar
Thien Tran committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98


def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.JPEG` for details.

    Dispatches to the ``jpeg`` kernel registered for ``type(image)``.

    Args:
        image: Input to pass through the kernel.
        quality: Forwarded unchanged to the selected kernel.

    Returns:
        The result of the selected kernel.
    """
    # No dynamic dispatch while scripting: call the plain tensor kernel directly.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    dispatched = _get_kernel(jpeg, type(image))
    return dispatched(image, quality=quality)


@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, tv_tensors.Image)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    """Tensor kernel for :func:`jpeg`.

    Flattens all leading dimensions into one, passes each slice of the last
    three dimensions through ``encode_jpeg`` followed by ``decode_jpeg``, and
    restores the original shape.

    Args:
        image: Tensor whose last three dimensions form one image.
        quality: Passed to ``encode_jpeg`` as ``quality``.

    Returns:
        A new tensor with the same shape as ``image``.
    """
    original_shape = image.shape
    # ``reshape`` rather than ``view`` so that non-contiguous inputs
    # (e.g. permuted or flipped tensors) do not raise.
    batch = image.reshape((-1,) + image.shape[-3:])

    if batch.shape[0] == 0:  # degenerate
        return batch.reshape(original_shape).clone()

    # Keep the decoded frames in their own list instead of rebinding ``image``
    # from a Tensor to a list, so every name has a single static type.
    decoded = [decode_jpeg(encode_jpeg(frame, quality=quality)) for frame in batch]
    return torch.stack(decoded, dim=0).view(original_shape)


@_register_kernel_internal(jpeg, tv_tensors.Video)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    """Video kernel for :func:`jpeg`; delegates to ``jpeg_image``."""
    compressed = jpeg_image(video, quality=quality)
    return compressed


@_register_kernel_internal(jpeg, PIL.Image.Image)
def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image:
    """PIL kernel for :func:`jpeg`.

    Saves ``image`` as JPEG with the given ``quality`` into an in-memory
    buffer and reads it back.

    Returns:
        A copy of the image re-opened from the in-memory JPEG data.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)

    # PIL.Image.open() yields a PIL.JpegImagePlugin.JpegImageFile, a sub-class
    # of PIL.Image.Image, which would fail the check_transform() test.
    # Copying the reopened image is what lets us return something else.
    reopened = PIL.Image.open(buffer)
    return reopened.copy()