image.py 11.4 KB
Newer Older
1
from enum import Enum
2
from warnings import warn
3

4
5
import torch

6
from ..extension import _load_library
Kai Zhang's avatar
Kai Zhang committed
7
from ..utils import _log_api_usage_once
8
9


# Load the compiled "image" extension; the functions below call ops under ``torch.ops.image``,
# which presumably become available through this load. A failure is reported as a warning rather
# than raised, so importing this module still succeeds when image functionality is not needed.
try:
    _load_library("image")
except (ImportError, OSError) as e:
    # The first piece previously had no separator, producing "...'<error>'If you don't plan...".
    warn(
        f"Failed to load image Python extension: '{e}'. "
        "If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        "Otherwise, there might be something wrong with your environment. "
        "Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )


class ImageReadMode(Enum):
    """
    Read modes accepted by the image decoding functions of this module.

    Available members:

    * ``ImageReadMode.UNCHANGED`` -- load the image as-is.
    * ``ImageReadMode.GRAY`` -- convert the image to grayscale.
    * ``ImageReadMode.GRAY_ALPHA`` -- convert to grayscale, keeping transparency.
    * ``ImageReadMode.RGB`` -- convert the image to RGB.
    * ``ImageReadMode.RGB_ALPHA`` -- convert to RGB, keeping transparency.
    """

    # The decoders forward ``mode.value`` (these integers) to the native ops.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


def read_file(path: str) -> torch.Tensor:
    """
    Read a file from disk and return its raw bytes as a tensor.

    Args:
        path (str or ``pathlib.Path``): location of the file to read. It is converted
            with ``str()`` before being handed to the native op.

    Returns:
        data (Tensor): a one dimensional uint8 tensor holding the file contents.
    """
    # API-usage logging only happens in eager mode (not while scripting or tracing).
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    return torch.ops.image.read_file(str(path))


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Write the contents of a one dimensional uint8 tensor to a file.

    Args:
        filename (str or ``pathlib.Path``): path of the file to write. It is converted
            with ``str()`` before being handed to the native op.
        data (Tensor): the bytes to write to the output file.
    """
    # API-usage logging only happens in eager mode (not while scripting or tracing).
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    torch.ops.image.write_file(str(filename), data)


def decode_png(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
    """
    Decode a PNG image into a 3 dimensional RGB or grayscale tensor, optionally
    converting it to the requested format.

    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes
            of the PNG image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``. See the ``ImageReadMode`` class for
            the available modes.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # The third op argument is always False here. The private ``_read_png_16`` passes True
    # for it, so it presumably toggles 16-bit output -- confirm against the op schema.
    return torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode an image tensor in CHW layout as PNG and return the encoded bytes.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor with
            ``c`` channels, where ``c`` must be 1 or 3.
        compression_level (int): compression factor for the resulting file; it must be
            a number between 0 and 9. Default: 6

    Returns:
        Tensor[1]: a one dimensional uint8 tensor containing the raw bytes of the
            PNG file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    # No Python-side check of ``compression_level`` happens here; any range
    # validation is presumably done by the native op.
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Encode an image tensor as PNG and save it to ``filename``.

    This is ``encode_png`` followed by ``write_file``. ``input`` is passed to
    ``encode_png`` unchanged, so the same layout and dtype requirements apply.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor with
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): path to save the image to.
        compression_level (int): compression factor for the resulting file; it must be
            a number between 0 and 9. Default: 6
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    write_file(filename, encode_png(input, compression_level))


def decode_jpeg(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: str = "cpu",
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format on CPU; when a CUDA device is
            requested the flag is ignored and a warning is emitted if it is set.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    # Use a separate name instead of rebinding the ``str``-annotated parameter to a device.
    target_device = torch.device(device)
    if target_device.type == "cuda":
        if apply_exif_orientation:
            # The CUDA op takes no EXIF argument, so the request cannot be honoured.
            # Say so rather than silently returning an un-oriented image.
            warn("decode_jpeg: apply_exif_orientation=True is not supported on CUDA devices and is ignored.")
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, target_device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
    return output


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode an image tensor in CHW layout as JPEG and return the encoded bytes.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor with
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): quality of the resulting JPEG file; it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of
            the JPEG file.

    Raises:
        ValueError: if ``quality`` is lower than 1 or greater than 100.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    # Reject out-of-range quality values before handing off to the native encoder.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Encode an image tensor in CHW layout as JPEG and save it to ``filename``.

    This is ``encode_jpeg`` followed by ``write_file``. ``input`` and ``quality``
    are passed to ``encode_jpeg`` unchanged, so its requirements (and its
    ``ValueError`` for an out-of-range quality) apply here too.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor with
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): path to save the image to.
        quality (int): quality of the resulting JPEG file; it must be a number
            between 1 and 100. Default: 75
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    write_file(filename, encode_jpeg(input, quality))


def decode_image(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
    """
    Detect whether an image is a JPEG, PNG or GIF and decode it into a 3 dimensional
    RGB or grayscale tensor, optionally converting it to the requested format.

    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            JPEG, PNG or GIF image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``. See the ``ImageReadMode`` class for
            the available modes. Ignored for GIFs.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    # Format detection and dispatch happen inside the native op.
    return torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)


def read_image(
    path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
    """
    Read a JPEG, PNG or GIF image file into a 3 dimensional RGB or grayscale tensor,
    optionally converting it to the requested format.

    The values of the output tensor are uint8 in [0, 255]. This is ``read_file``
    followed by ``decode_image``.

    Args:
        path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
        mode (ImageReadMode): read mode used to optionally convert the image.
            Default: ``ImageReadMode.UNCHANGED``. See the ``ImageReadMode`` class for
            the available modes. Ignored for GIFs.
        apply_exif_orientation (bool): whether to apply the EXIF orientation
            transformation to the output tensor. Ignored for GIFs. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    return decode_image(read_file(path), mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read a PNG file and decode it with the native op's third
    argument set to True.

    Going by the name, that flag presumably enables 16-bit output -- confirm against
    the op schema. Unlike ``decode_png``, this helper does not log its own API usage
    and does not pass an EXIF-orientation argument to the op.
    """
    raw_bytes = read_file(path)
    return torch.ops.image.decode_png(raw_bytes, mode.value, True)


def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a GIF image into a 3 or 4 dimensional RGB tensor.

    The output values are uint8 in [0, 255]. A GIF containing a single image yields
    a tensor of shape ``(C, H, W)``; a GIF containing ``N`` images yields a tensor of
    shape ``(N, C, H, W)``.

    Args:
        input (Tensor[1]): a one dimensional, contiguous uint8 tensor containing the
            raw bytes of the GIF image.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or
            Tensor[num_images, image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_gif)
    frames = torch.ops.image.decode_gif(input)
    return frames