image.py 12.4 KB
Newer Older
1
from enum import Enum
2
from typing import List, Union
3
from warnings import warn
4

5
6
import torch

7
from ..extension import _load_library
Kai Zhang's avatar
Kai Zhang committed
8
from ..utils import _log_api_usage_once
9
10


11
try:
    _load_library("image")
except (ImportError, OSError) as e:
    # Loading the native extension is best-effort: only image I/O needs it,
    # so failure is reported as a warning instead of breaking the import.
    # Note the trailing space after '{e}' so the sentences don't run together.
    warn(
        f"Failed to load image Python extension: '{e}' "
        "If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        "Otherwise, there might be something wrong with your environment. "
        "Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )
20
21


22
class ImageReadMode(Enum):
    """
    Read modes that control how a decoded image is converted.

    * ``ImageReadMode.UNCHANGED`` keeps the image as it is stored.
    * ``ImageReadMode.GRAY`` converts it to grayscale.
    * ``ImageReadMode.GRAY_ALPHA`` converts it to grayscale with transparency.
    * ``ImageReadMode.RGB`` converts it to RGB.
    * ``ImageReadMode.RGB_ALPHA`` converts it to RGB with transparency.
    """

    # The integer values are handed to the native decoders through
    # ``mode.value``, so they must keep these exact numbers.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


Francisco Massa's avatar
Francisco Massa committed
40
41
42
43
44
def read_file(path: str) -> torch.Tensor:
    """
    Load the bytes of a file into a one dimensional uint8 Tensor.

    Args:
        path (str or ``pathlib.Path``): location of the file to read. It is
            converted with ``str()`` before being passed to the native op.

    Returns:
        data (Tensor)
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    return torch.ops.image.read_file(str(path))


Francisco Massa's avatar
Francisco Massa committed
57
58
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Store the contents of a one dimensional uint8 tensor in a file.

    Args:
        filename (str or ``pathlib.Path``): destination of the write. It is
            converted with ``str()`` before being passed to the native op.
        data (Tensor): the bytes to write to the output file
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    destination = str(filename)
    torch.ops.image.write_file(destination, data)
Francisco Massa's avatar
Francisco Massa committed
69
70


71
def decode_png(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decode a PNG image into a 3 dimensional RGB or grayscale Tensor.

    The image can optionally be converted to another format through ``mode``.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw bytes
            of the PNG image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See `ImageReadMode` class for more information on the
            available modes.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # Third positional flag of the native op: only the private ``_read_png_16``
    # helper passes True (presumably enabling 16-bit output); the public API
    # always passes False.
    wide_samples = False
    return torch.ops.image.decode_png(input, mode.value, wide_samples, apply_exif_orientation)


100
101
102
103
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
            PNG file.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    # Range validation of ``compression_level`` is left to the native op.
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode to an in-memory PNG byte buffer, then write those bytes to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)
137
138


139
def decode_jpeg(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: str = "cpu",
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decode a JPEG image into a 3 dimensional RGB or grayscale Tensor.

    The image can optionally be converted to another format through ``mode``.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw bytes
            of the JPEG image. This tensor must live on CPU, whatever value
            is given for ``device``.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Supported modes: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on the
            available modes.
        device (str or torch.device): device that will hold the decoded image.
            When a cuda device is given, decoding is performed with
            `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Default: False. Only honoured
            on CPU; the CUDA path does not receive this flag.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    target_device = torch.device(device)
    if target_device.type != "cuda":
        return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
    # GPU decode: ``apply_exif_orientation`` is not forwarded on this path.
    return torch.ops.image.decode_jpeg_cuda(input, mode.value, target_device)


186
187
188
def encode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Encode one CHW image tensor, or a list of them, into JPEG byte buffer(s).

    A single tensor yields a single buffer; a list yields a list of buffers.

    .. note::
        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
        For CPU tensors the performance is equivalent.

    Args:
        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.

    Raises:
        ValueError: if ``quality`` is outside [1, 100], or if ``input`` is an
            empty list.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    if quality > 100 or quality < 1:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    if isinstance(input, list):
        if len(input) == 0:
            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
        # The first tensor's device selects the backend for the whole batch.
        if input[0].device.type != "cuda":
            return [torch.ops.image.encode_jpeg(image, quality) for image in input]
        return torch.ops.image.encode_jpegs_cuda(input, quality)
    else:
        if input.device.type != "cuda":
            return torch.ops.image.encode_jpeg(input, quality)
        # The CUDA op takes a batch: wrap the lone tensor, then unwrap the result.
        return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
222
223
224
225


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Raises:
        ValueError: if ``quality`` is outside [1, 100] (raised by ``encode_jpeg``).
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    output = encode_jpeg(input, quality)
    # ``encode_jpeg`` is typed as returning a Union; a single tensor input gives
    # a single tensor, and this isinstance refines the type for TorchScript.
    assert isinstance(output, torch.Tensor)  # Needed for torchscript
    write_file(filename, output)
Francisco Massa's avatar
Francisco Massa committed
240
241


242
def decode_image(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Detect whether an image is a JPEG, PNG or GIF and perform the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            JPEG, PNG or GIF image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes. Ignored for GIFs.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width]).
        For GIFs, see :func:`decode_gif` for the output shape.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    # Format detection and dispatch happen inside the native op.
    output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
    return output


273
def read_image(
    path: str,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Read a JPEG, PNG or GIF file from disk and decode it into a 3 dimensional
    RGB or grayscale Tensor.

    This is ``read_file`` followed by ``decode_image``; the image can
    optionally be converted to another format through ``mode``.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on the
            available modes. Ignored for GIFs.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    encoded = read_file(path)
    return decode_image(encoded, mode=mode, apply_exif_orientation=apply_exif_orientation)
299
300
301
302
303


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read a PNG file and decode it with the native
    ``decode_png`` op's third flag set to ``True`` (the public ``decode_png``
    passes ``False``). No usage logging is performed here.

    NOTE(review): judging by the name, the flag keeps 16-bit samples; confirm
    the resulting dtype against the native op before relying on it.
    """
    return torch.ops.image.decode_png(read_file(path), mode.value, True)
Nicolas Hug's avatar
Nicolas Hug committed
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323


def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a GIF image into a 3 or 4 dimensional RGB Tensor.

    Output values are uint8 between 0 and 255. A GIF holding a single image
    gives a tensor of shape ``(C, H, W)``; a GIF holding ``N`` images gives
    ``(N, C, H, W)``.

    Args:
        input (Tensor[1]): one dimensional contiguous uint8 tensor holding
            the raw bytes of the GIF image.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_gif)
    frames = torch.ops.image.decode_gif(input)
    return frames