image.py 7.7 KB
Newer Older
1
import torch
2
3
from enum import Enum

4
from .._register_extension import _get_extension_path
5
6


7
8
9
# Load the compiled ``image`` extension; the functions in this module call
# operators under ``torch.ops.image``. Loading stays best-effort so the module
# remains importable when the extension is unavailable, but the failure is
# reported instead of being silently discarded, which would otherwise only
# surface later as a missing-operator error at call time.
try:
    lib_path = _get_extension_path('image')
    torch.ops.load_library(lib_path)
except (ImportError, OSError) as load_error:
    # Imported here so this block does not depend on the file's import header.
    import warnings

    warnings.warn(
        "Failed to load image Python extension: {}".format(load_error)
    )


14
class ImageReadMode(Enum):
    """
    Read modes that select how an image is converted while it is read.

    Available modes:

    - ``ImageReadMode.UNCHANGED``: load the image as-is.
    - ``ImageReadMode.GRAY``: convert to grayscale.
    - ``ImageReadMode.GRAY_ALPHA``: grayscale with transparency.
    - ``ImageReadMode.RGB``: convert to RGB.
    - ``ImageReadMode.RGB_ALPHA``: RGB with transparency.
    """

    # The integer values are what the decode functions in this module forward
    # (via ``mode.value``) to the ``torch.ops.image`` operators.
    UNCHANGED = 0   # image as-is
    GRAY = 1        # grayscale
    GRAY_ALPHA = 2  # grayscale with transparency
    RGB = 3         # RGB
    RGB_ALPHA = 4   # RGB with transparency


Francisco Massa's avatar
Francisco Massa committed
31
32
33
34
35
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw byte contents of a file into a one dimensional
    uint8 Tensor.

    Args:
        path (str): location of the file to read

    Returns:
        data (Tensor)
    """
    # Thin wrapper: the work is done by the ``image::read_file`` operator.
    return torch.ops.image.read_file(path)


Francisco Massa's avatar
Francisco Massa committed
46
47
48
49
50
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump a one dimensional uint8 tensor to a file.

    Args:
        filename (str): location of the file to write
        data (Tensor): the bytes that make up the output file
    """
    # Thin wrapper: the work is done by the ``image::write_file`` operator.
    torch.ops.image.write_file(filename, data)


58
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode the bytes of a PNG image into a 3 dimensional RGB Tensor,
    optionally converting it to the requested format.
    The output tensor holds uint8 values between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the PNG image.
        mode (ImageReadMode): read mode used to optionally convert the
            image. Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The native operator receives the enum's integer value.
    return torch.ops.image.decode_png(input, mode.value)


79
80
81
82
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode a CHW image tensor as PNG and return the bytes of the
    resulting file as a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    # NOTE(review): the docstring says int8, while the decoders in this file
    # document uint8 output -- confirm which dtype the native encoder expects.
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Save an image tensor in CHW layout (or HW in the case of grayscale
    images) as a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Encode to an in-memory PNG buffer, then write that buffer to disk.
    write_file(filename, encode_png(input, compression_level))
112
113


114
115
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
                device: str = 'cpu') -> torch.Tensor:
    """
    Decode the bytes of a JPEG image into a 3 dimensional RGB Tensor,
    optionally converting it to the requested format.
    The output tensor holds uint8 values between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): read mode used to optionally convert the
            image. Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Normalise the argument so both strings and torch.device objects work.
    target = torch.device(device)
    if target.type != 'cuda':
        return torch.ops.image.decode_jpeg(input, mode.value)
    # CUDA targets go through the dedicated operator, which also needs the device.
    return torch.ops.image.decode_jpeg_cuda(input, mode.value, target)


145
146
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode a CHW image tensor as JPEG and return the bytes of the
    resulting file as a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is lower than 1 or greater than 100.
    """
    # NOTE(review): the docstring says int8, while the decoders in this file
    # document uint8 output -- confirm which dtype the native encoder expects.

    # Reject out-of-range quality before reaching the native operator.
    if quality > 100 or quality < 1:
        raise ValueError(
            'Image quality should be a positive number between 1 and 100'
        )

    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Save an image tensor in CHW layout as a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # Encode to an in-memory JPEG buffer, then write that buffer to disk.
    write_file(filename, encode_jpeg(input, quality))
Francisco Massa's avatar
Francisco Massa committed
181
182


183
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detect whether the given bytes are a JPEG or a PNG image and run the
    matching decoder, producing a 3 dimensional RGB Tensor.

    Optionally converts the image to the desired format.
    The output tensor holds uint8 values between 0 and 255.

    Args:
        input (Tensor): one dimensional uint8 tensor holding the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The native operator receives the enum's integer value.
    return torch.ops.image.decode_image(input, mode.value)


206
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Load a JPEG or PNG image from disk into a 3 dimensional RGB Tensor,
    optionally converting it to the requested format.
    The output tensor holds uint8 values between 0 and 255.

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Read the raw file bytes, then let decode_image pick the decoder.
    return decode_image(read_file(path), mode)