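"""
Image I/O helpers backed by the compiled ``image`` extension: reading and
writing raw files, and decoding/encoding PNG and JPEG images as tensors.
"""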
import torch

import os
import os.path as osp
import importlib.machinery
from enum import Enum

# Set to True only once the compiled ``image`` extension has been located and
# registered with ``torch.ops``.
_HAS_IMAGE_OPT = False

try:
    # The extension is searched for in the parent directory of the directory
    # that contains this file.
    lib_dir = osp.abspath(osp.join(osp.dirname(__file__), ".."))

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")

    if os.name == 'nt':
        # Load the image extension using LoadLibraryExW
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # 0x0001 is SEM_FAILCRITICALERRORS: a failed load must not pop up a
        # system error dialog.
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        try:
            kernel32.LoadLibraryW.restype = ctypes.c_void_p
            if with_load_library_flags:
                kernel32.LoadLibraryExW.restype = ctypes.c_void_p

            if ext_specs is not None:
                # 0x00001100 is LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
                # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR.
                res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += (f' Error loading "{ext_specs.origin}" or any of '
                                     'its dependencies.')
                    raise err
        finally:
            # Always restore the previous error mode: the error raised above
            # is swallowed by the outer ``except``, so without this the
            # process would keep running with the modified mode.
            kernel32.SetErrorMode(prev_error_mode)

    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Best effort: when the extension is missing or cannot be loaded the
    # module still imports and ``_HAS_IMAGE_OPT`` stays False.
    pass


class ImageReadMode(Enum):
    """
    Read modes accepted by the image decoding functions of this module.

    The integer value of the selected member is what gets forwarded to the
    underlying decoding operator.
    """
    # Load the image as-is, without any conversion.
    UNCHANGED = 0
    # Convert to grayscale.
    GRAY = 1
    # Convert to grayscale with transparency.
    GRAY_ALPHA = 2
    # Convert to RGB.
    RGB = 3
    # Convert to RGB with transparency.
    RGB_ALPHA = 4


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Parameters
    ----------
    path: str
        the path to the file to be read

    Returns
    -------
    data: Tensor
        a one dimensional uint8 tensor holding the contents of the file
    """
    data = torch.ops.image.read_file(path)
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Stores a one dimensional uint8 tensor on disk as the raw contents
    of a file.

    Parameters
    ----------
    filename: str
        the path to the file to be written
    data: Tensor
        the contents to be written to the output file
    """
    write_op = torch.ops.image.write_file
    write_op(filename, data)


def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor[1]
        a one dimensional uint8 tensor containing the raw bytes of the
        PNG image.
    mode: ImageReadMode
        the read mode used for optionally converting the image.
        Use `ImageReadMode.UNCHANGED` for loading the image as-is,
        `ImageReadMode.GRAY` for converting to grayscale,
        `ImageReadMode.GRAY_ALPHA` for grayscale with transparency,
        `ImageReadMode.RGB` for RGB and `ImageReadMode.RGB_ALPHA` for
        RGB with transparency. Default: `ImageReadMode.UNCHANGED`

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    # The operator takes the integer value of the enum member.
    output = torch.ops.image.decode_png(input, mode.value)
    return output


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        PNG file.
    """
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6
    """
    # Encode in memory first, then dump the resulting bytes to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)


def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor[1]
        a one dimensional uint8 tensor containing the raw bytes of the
        JPEG image.
    mode: ImageReadMode
        the read mode used for optionally converting the image.
        Use `ImageReadMode.UNCHANGED` for loading the image as-is,
        `ImageReadMode.GRAY` for converting to grayscale and
        `ImageReadMode.RGB` for RGB. Default: `ImageReadMode.UNCHANGED`

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    # The operator takes the integer value of the enum member.
    output = torch.ops.image.decode_jpeg(input, mode.value)
    return output


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    quality: int
        Quality of the resulting JPEG file, it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        JPEG file.

    Raises
    ------
    ValueError
        If `quality` is not between 1 and 100.
    """
    # Reject an invalid quality before handing anything to the operator.
    if not 1 <= quality <= 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    quality: int
        Quality of the resulting JPEG file, it must be a number
        between 1 and 100. Default: 75
    """
    # Encode in memory first, then dump the resulting bytes to disk.
    output = encode_jpeg(input, quality)
    write_file(filename, output)


def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor
        a one dimensional uint8 tensor containing the raw bytes of the
        PNG or JPEG image.
    mode: ImageReadMode
        the read mode used for optionally converting the image. JPEG
        and PNG images have different permitted values. The default
        value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
        See `decode_jpeg()` and `decode_png()` for more information.
        Default: `ImageReadMode.UNCHANGED`

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    # The operator takes the integer value of the enum member.
    output = torch.ops.image.decode_image(input, mode.value)
    return output


def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    path: str
        path of the JPEG or PNG image.
    mode: ImageReadMode
        the read mode used for optionally converting the image. JPEG
        and PNG images have different permitted values. The default
        value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
        See `decode_jpeg()` and `decode_png()` for more information.
        Default: `ImageReadMode.UNCHANGED`

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    # Load the raw bytes, then let decode_image pick the matching decoder.
    data = read_file(path)
    return decode_image(data, mode)