_augment.py 1.82 KB
Newer Older
1
2
from typing import Union

3
4
5
import PIL.Image

import torch
6
from torchvision import datapoints
7
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
8
from torchvision.utils import _log_api_usage_once
9

10
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
11

12

13
14
15
16
17
18
19
20
21
22
@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
def erase(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
    """Dispatch an erase request to the kernel registered for ``type(inpt)``.

    Args:
        inpt: The image or video to erase a region from.
        i: Start index of the erased region along the second-to-last dimension.
        j: Start index of the erased region along the last dimension.
        h: Extent of the erased region along the second-to-last dimension.
        w: Extent of the erased region along the last dimension.
        v: Value(s) assigned to the erased region.
        inplace: Forwarded unchanged to the selected kernel.

    Returns:
        Whatever the selected kernel returns.

    Note:
        ``datapoints.Mask`` and ``datapoints.BoundingBoxes`` are registered on
        this dispatcher through ``_register_explicit_noop`` with
        ``warn_passthrough=True``; presumably they are returned untouched with
        a warning (confirm against ``._utils``).
    """
    if not torch.jit.is_scripting():
        # Eager mode: record the API usage, then look the kernel up by the
        # concrete input type.
        _log_api_usage_once(erase)
        return _get_kernel(erase, type(inpt))(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

    # Under TorchScript there is no dynamic dispatch; the plain tensor kernel
    # is called directly.
    return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
30
31


32
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Overwrite a rectangular window of the last two dimensions of ``image`` with ``v``.

    Args:
        image: Tensor to erase from; the window is taken over its last two
            dimensions, any leading dimensions are covered by ``...``.
        i: Start index of the window along the second-to-last dimension.
        j: Start index of the window along the last dimension.
        h: Extent of the window along the second-to-last dimension.
        w: Extent of the window along the last dimension.
        v: Value(s) assigned to the window (presumably must be assignable to
            the sliced region; not validated here).
        inplace: If ``True``, ``image`` itself is modified and returned;
            otherwise a clone is modified and returned.

    Returns:
        The tensor with the window overwritten.
    """
    # Work on a copy unless the caller explicitly asked for in-place mutation.
    target = image if inplace else image.clone()

    row_end = i + h
    col_end = j + w
    target[..., i:row_end, j:col_end] = v
    return target
42
43


44
@_register_kernel_internal(erase, PIL.Image.Image)
def erase_image_pil(
    image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
    """Erase a window from a PIL image by round-tripping through a tensor.

    The image is converted with ``pil_to_tensor``, erased with
    ``erase_image_tensor``, and converted back with ``to_pil_image`` using the
    input image's ``mode``.

    Args:
        image: PIL image to erase from.
        i: Start index of the window along the second-to-last tensor dimension.
        j: Start index of the window along the last tensor dimension.
        h: Extent of the window along the second-to-last tensor dimension.
        w: Extent of the window along the last tensor dimension.
        v: Value(s) assigned to the window.
        inplace: Forwarded to ``erase_image_tensor``; it applies to the
            intermediate tensor, and a new PIL image is returned either way.

    Returns:
        A PIL image built from the erased tensor, in the same mode as ``image``.
    """
    erased = erase_image_tensor(pil_to_tensor(image), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return to_pil_image(erased, mode=image.mode)
51
52


53
@_register_kernel_internal(erase, datapoints.Video)
def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Erase a window from a video tensor.

    Delegates unchanged to ``erase_image_tensor``, which indexes only the last
    two dimensions, so any leading dimensions of ``video`` are covered by the
    same window.

    Args:
        video: Video tensor to erase from.
        i: Start index of the window along the second-to-last dimension.
        j: Start index of the window along the last dimension.
        h: Extent of the window along the second-to-last dimension.
        w: Extent of the window along the last dimension.
        v: Value(s) assigned to the window.
        inplace: If ``True``, ``video`` itself is modified; otherwise a clone is.

    Returns:
        The tensor returned by ``erase_image_tensor``.
    """
    erased = erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return erased