functional_tensor.py 3.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torchvision.transforms.functional as F


def vflip(img_tensor):
    """Flip an image tensor upside down.

    The flip is applied along the second-to-last dimension, which is the
    height axis for the documented [C, H, W] layout.

    Args:
        img_tensor (Tensor): Image tensor to flip, laid out as [C, H, W].

    Returns:
        Tensor: The vertically flipped image tensor.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img_tensor``.
    """
    if F._is_tensor_image(img_tensor):
        return img_tensor.flip(-2)

    raise TypeError('tensor is not a torch image.')


def hflip(img_tensor):
    """Mirror an image tensor left-to-right.

    The flip is applied along the last dimension, which is the width axis
    for the documented [C, H, W] layout.

    Args:
        img_tensor (Tensor): Image tensor to flip, laid out as [C, H, W].

    Returns:
        Tensor: The horizontally flipped image tensor.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img_tensor``.
    """
    if F._is_tensor_image(img_tensor):
        return img_tensor.flip(-1)

    raise TypeError('tensor is not a torch image.')
ekka's avatar
ekka committed
33
34
35
36


def crop(img, top, left, height, width):
    """Cut a rectangular region out of an image tensor.

    The crop box is addressed on the last two dimensions (rows, then
    columns); any leading dimensions are kept as they are.

    Args:
        img (Tensor): Image to crop, laid out as [C, H, W]. (0, 0) is the
            top left corner of the image.
        top (int): Row index of the crop box's top left corner.
        left (int): Column index of the crop box's top left corner.
        height (int): Number of rows in the crop box.
        width (int): Number of columns in the crop box.

    Returns:
        Tensor: The cropped image.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img``.
    """
    if F._is_tensor_image(img):
        rows = slice(top, top + height)
        cols = slice(left, left + width)
        return img[..., rows, cols]

    raise TypeError('tensor is not a torch image.')
52
53


54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def rgb_to_grayscale(img):
    """Collapse a 3-channel RGB image tensor into a single luma plane.

    Uses the ITU-R 601-2 luma transform:
    L = R * 0.2989 + G * 0.5870 + B * 0.1140

    Args:
        img (Tensor): RGB image laid out as [C, H, W] with C == 3.

    Returns:
        Tensor: Grayscale image with the channel dimension removed, cast
        back to the dtype of ``img``.

    Raises:
        TypeError: If the first dimension of ``img`` is not of size 3.
    """
    num_channels = img.shape[0]
    if num_channels != 3:
        raise TypeError('Input Image does not contain 3 Channels')

    red, green, blue = img[0], img[1], img[2]
    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    # Cast back so integer inputs yield integer outputs.
    return luma.to(img.dtype)


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
def adjust_brightness(img, brightness_factor):
    """Scale the brightness of an RGB image tensor.

    Implemented as a blend between ``img`` and an all-zero (black) image,
    weighted by ``brightness_factor``.

    Args:
        img (Tensor): Image to be adjusted.
        brightness_factor (float): Brightness multiplier. Can be any non
            negative number: 0 gives a black image, 1 gives the original
            image and 2 doubles the brightness.

    Returns:
        Tensor: Brightness adjusted image.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img``.
    """
    if F._is_tensor_image(img):
        black = 0
        return _blend(img, black, brightness_factor)

    raise TypeError('tensor is not a torch image.')


def adjust_contrast(img, contrast_factor):
    """Scale the contrast of an RGB image tensor.

    Implemented as a blend between ``img`` and the scalar mean of its
    grayscale version, weighted by ``contrast_factor``.

    Args:
        img (Tensor): Image to be adjusted.
        contrast_factor (float): Contrast multiplier. Can be any non
            negative number: 0 gives a solid gray image, 1 gives the
            original image and 2 doubles the contrast.

    Returns:
        Tensor: Contrast adjusted image.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img``.
    """
    if F._is_tensor_image(img):
        gray = rgb_to_grayscale(img)
        # Average in float so integer images do not lose the fractional part.
        mean = gray.to(torch.float).mean()
        return _blend(img, mean, contrast_factor)

    raise TypeError('tensor is not a torch image.')


def adjust_saturation(img, saturation_factor):
    """Scale the color saturation of an RGB image tensor.

    Implemented as a blend between ``img`` and its grayscale version,
    weighted by ``saturation_factor``.

    Args:
        img (Tensor): Image to be adjusted.
        saturation_factor (float): Saturation multiplier. 0 gives a black
            and white image, 1 gives the original image and 2 doubles the
            saturation.

    Returns:
        Tensor: Saturation adjusted image.

    Raises:
        TypeError: If ``F._is_tensor_image`` rejects ``img``.
    """
    if F._is_tensor_image(img):
        gray = rgb_to_grayscale(img)
        return _blend(img, gray, saturation_factor)

    raise TypeError('tensor is not a torch image.')
126
127
128
129
130


def _blend(img1, img2, ratio):
    """Linearly blend two images and clamp the result to the valid range.

    Computes ``ratio * img1 + (1 - ratio) * img2``, clamps it to
    ``[0, 1]`` when ``img1`` has a floating point dtype and to ``[0, 255]``
    otherwise, then casts the result back to the dtype of ``img1``.

    Args:
        img1 (Tensor): Primary image; its dtype selects the clamp bound and
            the output dtype.
        img2 (Tensor or number): Image or scalar blended with ``img1``; it
            must be broadcastable against ``img1``.
        ratio (float): Weight given to ``img1``; ``img2`` gets ``1 - ratio``.

    Returns:
        Tensor: Blended image with the same dtype as ``img1``.
    """
    # A Python int ratio would keep an integer image in its own dtype, so
    # e.g. ``2 * uint8`` wraps modulo 256 before the clamp can saturate it.
    # Forcing a float promotes the arithmetic to floating point first.
    ratio = float(ratio)
    bound = 1 if img1.dtype.is_floating_point else 255
    blended = ratio * img1 + (1 - ratio) * img2
    return blended.clamp(0, bound).to(img1.dtype)