image.py 9.35 KB
Newer Older
1
from enum import Enum
2
from warnings import warn
3

4
5
import torch

6
from ..extension import _load_library
Kai Zhang's avatar
Kai Zhang committed
7
from ..utils import _log_api_usage_once
8
9


10
# Try to load the "image" extension library at import time. A failure is
# reported as a warning instead of being raised, so importing this module
# still succeeds (presumably the torch.ops.image.* operators used below are
# what this load provides -- confirm against ``_load_library``).
try:
    _load_library("image")
except (OSError, ImportError) as e:
    warn(f"Failed to load image Python extension: {e}")
14
15


16
class ImageReadMode(Enum):
    """
    Read modes that select the channel layout of a decoded image.

    Pass ``ImageReadMode.UNCHANGED`` to load the image as-is,
    ``ImageReadMode.GRAY`` to convert it to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.

    The decode functions in this module forward the member's integer
    ``value`` to the underlying image operators.
    """

    # Keep the image exactly as stored.
    UNCHANGED = 0
    # Grayscale.
    GRAY = 1
    # Grayscale with transparency.
    GRAY_ALPHA = 2
    # RGB.
    RGB = 3
    # RGB with transparency.
    RGB_ALPHA = 4


Francisco Massa's avatar
Francisco Massa committed
34
35
36
37
38
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw byte contents of a file into a one dimensional
    uint8 Tensor.

    Args:
        path (str): location of the file to read

    Returns:
        data (Tensor): the bytes of the file, as returned by the
            ``image.read_file`` operator
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Usage logging only happens outside of scripting/tracing.
        _log_api_usage_once(read_file)
    return torch.ops.image.read_file(path)


Francisco Massa's avatar
Francisco Massa committed
51
52
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump a one dimensional uint8 tensor to a file.

    Args:
        filename (str): location of the file to write
        data (Tensor): the bytes that make up the content of the output file
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Usage logging only happens outside of scripting/tracing.
        _log_api_usage_once(write_file)
    # The operator performs the write; nothing is returned to the caller.
    torch.ops.image.write_file(filename, data)


65
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Turn the raw bytes of a PNG image into a 3 dimensional RGB or
    grayscale Tensor, optionally converting it to the requested format.
    Output values are uint8 in [0, 255].

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the PNG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            The `ImageReadMode` class describes the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Usage logging only happens outside of scripting/tracing.
        _log_api_usage_once(decode_png)
    # The trailing ``False`` is the same flag that ``_read_png_16`` in this
    # module passes as ``True`` (presumably the 16-bit toggle -- confirm
    # against the operator schema).
    return torch.ops.image.decode_png(input, mode.value, False)


88
89
90
91
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode a CHW image tensor as PNG and return the bytes of the
    resulting file in a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor of the produced file; it must
            be a number between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor holding the raw bytes of the
            PNG file.
    """
    # NOTE(review): the dtype is documented as ``int8`` here, while
    # ``read_file``/``write_file``/``decode_png`` in this module describe raw
    # bytes as uint8 -- confirm which dtype the native operator requires.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Save an image tensor in CHW layout (or HW for grayscale images)
    as a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Destination path of the image.
        compression_level (int): Compression factor of the produced file; it must
            be a number between 0 and 9. Default: 6
    """
    # NOTE(review): the input dtype is documented as ``int8``, yet the encoded
    # buffer is handed to ``write_file``, which documents a uint8 tensor --
    # confirm the dtype the native encoder expects.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode in memory first, then write the resulting byte buffer to disk.
    write_file(filename, encode_png(input, compression_level))
125
126


127
128
129
def decode_jpeg(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
    """
    Turn the raw bytes of a JPEG image into a 3 dimensional RGB or
    grayscale Tensor, optionally converting it to the requested format.
    Output values are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw
            bytes of the JPEG image. This tensor must live on CPU,
            whatever the ``device`` parameter is.
        mode (ImageReadMode): read mode used for the optional conversion.
            Supported modes: ``ImageReadMode.UNCHANGED``,
            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
            Default: ``ImageReadMode.UNCHANGED``.
            The ``ImageReadMode`` class describes the available modes.
        device (str or torch.device): Device that will hold the decoded
            image. When a cuda device is given, decoding is done with
            `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)

    # Normalise the device argument so its ``type`` can be inspected.
    target = torch.device(device)
    if target.type == "cuda":
        # CUDA targets go through the dedicated CUDA decoding operator.
        return torch.ops.image.decode_jpeg_cuda(input, mode.value, target)
    return torch.ops.image.decode_jpeg(input, mode.value)


169
170
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode a CHW image tensor as JPEG and return the bytes of the
    resulting file in a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the produced JPEG file; it must be a number
            between 1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional int8 tensor holding the raw bytes
            of the JPEG file.

    Raises:
        ValueError: if ``quality`` is below 1 or above 100.
    """
    # NOTE(review): the dtype is documented as ``int8`` here, while
    # ``read_file``/``write_file``/``decode_jpeg`` in this module describe raw
    # bytes as uint8 -- confirm which dtype the native operator requires.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)

    # Reject out-of-range qualities before reaching the native encoder.
    if quality > 100 or quality < 1:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Save an image tensor in CHW layout as a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Destination path of the image.
        quality (int): Quality of the produced JPEG file; it must be a number
            between 1 and 100. Default: 75
    """
    # NOTE(review): the input dtype is documented as ``int8``, yet the encoded
    # buffer is handed to ``write_file``, which documents a uint8 tensor --
    # confirm the dtype the native encoder expects.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    # Encode in memory first (this also validates ``quality``), then write
    # the resulting byte buffer to disk.
    write_file(filename, encode_jpeg(input, quality))
Francisco Massa's avatar
Francisco Massa committed
208
209


210
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode an image that is either a JPEG or a PNG: the format is detected
    and the matching decoding is applied, producing a 3 dimensional RGB or
    grayscale Tensor.

    The image is optionally converted to the requested format.
    Output values are uint8 in [0, 255].

    Args:
        input (Tensor): one dimensional uint8 tensor holding the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            The ``ImageReadMode`` class describes the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    # A single operator call handles both supported formats.
    return torch.ops.image.decode_image(input, mode.value)


235
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Load a JPEG or PNG image from disk into a 3 dimensional RGB or
    grayscale Tensor, optionally converting it to the requested format.
    Output values are uint8 in [0, 255].

    Args:
        path (str): location of the JPEG or PNG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            The ``ImageReadMode`` class describes the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    # Read the file's bytes, then let ``decode_image`` pick the decoder.
    return decode_image(read_file(path), mode)
255
256
257
258
259


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Read the file at ``path`` and decode it with the PNG operator, passing
    ``True`` as the operator's third argument (``decode_png`` passes ``False``).

    NOTE(review): judging by the name, that flag presumably enables 16-bit
    PNG decoding -- confirm against the ``image.decode_png`` operator schema.

    Args:
        path (str): location of the PNG file.
        mode (ImageReadMode): read mode forwarded to the operator.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        Tensor: the decoded image as returned by the operator.
    """
    raw_bytes = read_file(path)
    decoded = torch.ops.image.decode_png(raw_bytes, mode.value, True)
    return decoded