image.py 10.4 KB
Newer Older
1
from enum import Enum
2
from warnings import warn
3

4
5
import torch

6
from ..extension import _load_library
Kai Zhang's avatar
Kai Zhang committed
7
from ..utils import _log_api_usage_once
8
9


10
try:
    _load_library("image")
except (ImportError, OSError) as e:
    # A failure to load the native image library is reported as a warning
    # rather than an error, so importing this module still succeeds.
    warn(
        f"Failed to load image Python extension: '{e}' "
        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        f"Otherwise, there might be something wrong with your environment. "
        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )
19
20


21
class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.

    The decoding functions in this module pass the member's integer
    ``value`` to the underlying image ops, so the numbers below must not
    be changed.
    """

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


Francisco Massa's avatar
Francisco Massa committed
39
40
41
42
43
def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Args:
        path (str): the path to the file to be read

    Returns:
        data (Tensor): the raw bytes of the file.
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    data = torch.ops.image.read_file(path)
    return data


Francisco Massa's avatar
Francisco Massa committed
56
57
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of a uint8 tensor with one dimension to a
    file.

    Args:
        filename (str): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    torch.ops.image.write_file(filename, data)


70
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # The op receives the integer value of the read mode. The trailing False is
    # the flag that `_read_png_16` sets to True.
    # NOTE(review): presumably this toggles 16-bit PNG support - confirm against the native op.
    output = torch.ops.image.decode_png(input, mode.value, False)
    return output


93
94
95
96
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
            PNG file.
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode to an in-memory PNG buffer first, then write that buffer to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)
130
131


132
def decode_jpeg(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: str = "cpu",
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format on CPU.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    device = torch.device(device)
    if device.type == "cuda":
        # The CUDA op is not given ``apply_exif_orientation``; the flag only
        # has an effect on the CPU path below.
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
    return output


179
180
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional uint8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is not between 1 and 100.
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    # Reject out-of-range quality values before calling the native op.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    # Encode to an in-memory JPEG buffer first, then write that buffer to disk.
    output = encode_jpeg(input, quality)
    write_file(filename, output)
Francisco Massa's avatar
Francisco Massa committed
218
219


220
221
222
def decode_image(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
    return output


249
250
251
def read_image(
    path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API usage is only logged in eager mode, not while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    # Load the raw bytes from disk, then decode them.
    data = read_file(path)
    return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)
273
274
275
276
277


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads the file at ``path`` and decodes it as a PNG, calling the PNG
    decoding op with its last flag set to True (``decode_png`` passes False).

    NOTE(review): going by the function name, the flag presumably enables
    decoding of 16-bit PNGs - confirm against the native op.

    Args:
        path (str): path of the PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor): the decoded image.
    """
    data = read_file(path)
    return torch.ops.image.decode_png(data, mode.value, True)