from typing import Union

import PIL.Image
import torch

from torchvision import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import is_simple_tensor

def erase_image_tensor(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Write ``v`` into the rectangle ``[i, i + h) x [j, j + w)`` of ``image``.

    The rectangle indexes the last two (spatial) dimensions; leading dimensions
    (channels, batch) are untouched. Unless ``inplace`` is True, the input is
    cloned first and the original tensor is left unmodified.
    """
    out = image if inplace else image.clone()
    # v broadcasts against the selected region (scalar fill or patch).
    out[..., i : i + h, j : j + w] = v
    return out
21
22


23
@torch.jit.unused
def erase_image_pil(
    image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
    """PIL counterpart of :func:`erase_image_tensor`.

    Converts the PIL image to a tensor, erases the ``[i, i + h) x [j, j + w)``
    rectangle with ``v``, and converts back, preserving the original PIL mode.
    ``inplace`` only affects the intermediate tensor; the input PIL image is
    never mutated. Excluded from TorchScript via ``@torch.jit.unused``.
    """
    t_img = pil_to_tensor(image)
    output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return to_pil_image(output, mode=image.mode)
30
31


32
33
34
35
36
37
def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Erase the ``[i, i + h) x [j, j + w)`` rectangle in every frame of ``video``.

    Videos share the image kernel: the erase indexes only the trailing two
    spatial dimensions, so the temporal dimension broadcasts for free.
    """
    erased = erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return erased


38
def erase(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
    """Erase the ``[i, i + h) x [j, j + w)`` rectangle of ``inpt`` with value ``v``.

    Dispatches to the kernel matching the input type: plain tensors and
    TorchScript go through :func:`erase_image_tensor`; ``datapoints.Image`` /
    ``datapoints.Video`` are unwrapped, erased, and re-wrapped; PIL images go
    through :func:`erase_image_pil`.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(erase)

    # Scripted code only ever sees plain tensors, so the tensor kernel must be
    # the first branch under torch.jit.is_scripting().
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    elif isinstance(inpt, datapoints.Image):
        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        return datapoints.Image.wrap_like(inpt, output)
    elif isinstance(inpt, datapoints.Video):
        output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        return datapoints.Video.wrap_like(inpt, output)
    elif isinstance(inpt, PIL.Image.Image):
        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )