# torchvision.io image utilities: reading/writing files and encoding/decoding PNG, JPEG and GIF images.
from enum import Enum
from typing import List, Union
from warnings import warn

import torch

from ..extension import _load_library
from ..utils import _log_api_usage_once


try:
    _load_library("image")
except (ImportError, OSError) as e:
    # Failure to load the compiled extension is downgraded to a warning so this
    # module still imports; the ``torch.ops.image`` ops used below are presumably
    # unavailable in that case.
    warn(
        # The trailing space keeps the error text and the next sentence from
        # running together in the concatenated message.
        f"Failed to load image Python extension: '{e}' "
        "If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        "Otherwise, there might be something wrong with your environment. "
        "Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )


class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.
    """

    # The decoders receive ``mode.value`` (the integer), not the enum member, so these
    # numbers presumably have to stay in sync with the native ops — do not renumber.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Args:
        path (str or ``pathlib.Path``): the path to the file to be read

    Returns:
        data (Tensor): one dimensional uint8 tensor holding the raw bytes of the file.
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    # ``str()`` lets ``pathlib.Path`` arguments through even though the annotation
    # is ``str`` (presumably kept as ``str`` for TorchScript compatibility).
    data = torch.ops.image.read_file(str(path))
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of an uint8 tensor with one dimension to a
    file.

    Args:
        filename (str or ``pathlib.Path``): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    # ``str()`` lets ``pathlib.Path`` arguments through to the native op.
    torch.ops.image.write_file(str(filename), data)


def decode_png(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See `ImageReadMode` class for more information on various
            available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # The third positional argument is ``False`` here, while ``_read_png_16`` passes
    # ``True`` — presumably the flag that allows 16-bit output.
    output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
    return output


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
            PNG file.
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    # Range/dtype validation is left to the native op (nothing is checked here).
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode in memory, then dump the resulting byte tensor to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)


def decode_jpeg(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: str = "cpu",
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format on CPU; it is ignored
            when ``device`` is a CUDA device.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    device = torch.device(device)
    if device.type == "cuda":
        # The CUDA op takes no EXIF flag, so ``apply_exif_orientation`` is dropped here.
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
    return output


def encode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents
    of the corresponding JPEG file(s).

    .. note::
        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
        For CPU tensors the performance is equivalent.

    Args:
        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes
        of the JPEG file(s). A list input yields a list output; a single tensor yields a single tensor.

    Raises:
        ValueError: if ``quality`` is outside [1, 100], or if ``input`` is an empty list.
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    if isinstance(input, list):
        if not input:
            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
        # Only the first tensor's device selects the backend; the whole batch is
        # handed to the batched CUDA op in that case.
        if input[0].device.type == "cuda":
            return torch.ops.image.encode_jpegs_cuda(input, quality)
        else:
            return [torch.ops.image.encode_jpeg(image, quality) for image in input]
    else:  # single input tensor
        if input.device.type == "cuda":
            # The CUDA op is batched: wrap the tensor and unwrap the single result.
            return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
        else:
            return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    output = encode_jpeg(input, quality)
    # ``encode_jpeg`` returns ``Union[Tensor, List[Tensor]]``; a single-tensor input
    # yields a Tensor, and the assert narrows the type for TorchScript.
    assert isinstance(output, torch.Tensor)  # Needed for torchscript
    write_file(filename, output)


def decode_image(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Detect whether an image is a JPEG, PNG or GIF and performs the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            JPEG, PNG or GIF image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes. Ignored for GIFs.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    # Format detection happens inside the native op; nothing is sniffed in Python.
    output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
    return output


def read_image(
    path: str,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes. Ignored for GIFs.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging is skipped while scripting/tracing with TorchScript.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    # Two steps: load the raw file bytes, then let ``decode_image`` dispatch on the format.
    data = read_file(path)
    return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read a PNG file and decode it with 16-bit output allowed.

    Unlike ``decode_png`` (which passes ``False`` for the same flag), this passes
    ``True`` as the third argument of ``torch.ops.image.decode_png``, presumably so
    16-bit PNGs are not reduced to 8 bits. No API-usage logging is done here.

    Args:
        path (str or ``pathlib.Path``): path of the PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    data = read_file(path)
    # Pass every op argument explicitly, matching ``decode_png``'s call; EXIF
    # orientation is not applied on this path.
    return torch.ops.image.decode_png(data, mode.value, True, False)


def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Turn the raw bytes of a GIF file into an RGB image tensor.

    A GIF containing a single frame gives a 3 dimensional ``(C, H, W)`` tensor;
    a GIF containing ``N`` frames gives a 4 dimensional ``(N, C, H, W)`` tensor.
    Pixel values are uint8 in the range [0, 255].

    Args:
        input (Tensor[1]): contiguous, one dimensional uint8 tensor holding the
            encoded GIF bytes.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or
        Tensor[num_images, image_channels, image_height, image_width])
    """
    # The guard must stay inline in the ``if`` so TorchScript can drop the logging
    # call when scripting; tracing skips it as well.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_gif)
    frames = torch.ops.image.decode_gif(input)
    return frames