# image.py: PNG/JPEG file I/O, encoding and decoding helpers backed by the
# native ``torch.ops.image`` extension.
import torch

import os
import os.path as osp
import importlib.machinery

# True once the native "image" extension has been located and registered via
# torch.ops.load_library; stays False if it is missing or fails to load.
_HAS_IMAGE_OPT = False

try:
    # The compiled extension is looked up in the parent directory of this
    # module, as a file named "image" plus one of the platform's extension
    # module suffixes.
    lib_dir = osp.abspath(osp.join(osp.dirname(__file__), ".."))

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")

    if os.name == 'nt' and ext_specs is not None:
        # Pre-load the extension through the Win32 loader so that a failure
        # surfaces as an OSError naming the library before
        # torch.ops.load_library is attempted.
        import ctypes

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        # The LOAD_LIBRARY_SEARCH_* flags are only usable when the loader
        # supports AddDllDirectory.
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # 0x0001 = SEM_FAILCRITICALERRORS: fail the call instead of showing a
        # system error dialog.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            # c_void_p makes a NULL module handle come back as None and keeps
            # 64-bit handles from being truncated to a C int.
            kernel32.LoadLibraryW.restype = ctypes.c_void_p
            if with_load_library_flags:
                kernel32.LoadLibraryExW.restype = ctypes.c_void_p
                # 0x1100 = LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
                #          LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR, so dependencies
                # are also searched for next to the extension itself.
                res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100)
            else:
                # Without AddDllDirectory the flags above are unsupported.
                res = kernel32.LoadLibraryW(ext_specs.origin)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += (f' Error loading "{ext_specs.origin}" or one of '
                                 'its dependencies.')
                raise err
        finally:
            # Restore the previous error mode even when loading fails.
            kernel32.SetErrorMode(prev_error_mode)

    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Best effort: leave _HAS_IMAGE_OPT False when the extension is absent or
    # cannot be loaded.
    pass


def read_file(path: str) -> torch.Tensor:
    """
    Reads the whole contents of a file into a tensor of bytes.

    Parameters
    ----------
    path: str
        Path of the file to be read.

    Returns
    -------
    data: Tensor
        A one dimensional uint8 tensor holding the file's bytes.
    """
    return torch.ops.image.read_file(path)


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes a one dimensional uint8 tensor to a file, byte for byte.

    Parameters
    ----------
    filename: str
        Path of the file to be written.
    data: Tensor
        One dimensional uint8 tensor whose contents are written to the file.
    """
    write_op = torch.ops.image.write_file
    write_op(filename, data)


def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional uint8 tensor whose values lie
    between 0 and 255, optionally converting its number of color channels.

    Parameters
    ----------
    input: Tensor[1]
        A one dimensional uint8 tensor containing the raw bytes of the
        PNG image.
    channels: int
        Number of output channels for the decoded image: 0 keeps the
        original number of channels, 1 converts to grayscale, 2 to
        grayscale with alpha, 3 to RGB and 4 to RGB with alpha. Default: 0

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    return torch.ops.image.decode_png(input, channels)


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        PNG file (suitable for passing to `write_file`).
    """
    # The range of compression_level is not checked here; it is left to the
    # native op (NOTE(review): confirm the op rejects values outside 0-9).
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Encodes an image tensor in CHW layout (or HW in the case of grayscale
    images) as PNG and saves the result to a file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6
    """
    png_bytes = encode_png(input, compression_level)
    write_file(filename, png_bytes)


def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional uint8 tensor whose values lie
    between 0 and 255, optionally converting its number of color channels.

    Parameters
    ----------
    input: Tensor[1]
        A one dimensional uint8 tensor containing the raw bytes of the
        JPEG image.
    channels: int
        Number of output channels for the decoded image: 0 keeps the
        original number of channels, 1 converts to grayscale and 3 converts
        to RGB. Default: 0

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    return torch.ops.image.decode_jpeg(input, channels)


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    quality: int
        Quality of the resulting JPEG file, it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        JPEG file (suitable for passing to `write_file`).

    Raises
    ------
    ValueError
        If `quality` is lower than 1 or greater than 100.
    """
    # Reject out-of-range quality up front, before the native encoder runs.
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Encodes an image tensor in CHW layout as JPEG and saves the result to a
    file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    quality: int
        Quality of the resulting JPEG file, it must be a number
        between 1 and 100. Default: 75
    """
    jpeg_bytes = encode_jpeg(input, quality)
    write_file(filename, jpeg_bytes)


def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
    """
    Detects whether the given bytes hold a JPEG or a PNG image and decodes
    them accordingly into a 3 dimensional uint8 tensor with values between
    0 and 255, optionally converting the number of color channels.

    Parameters
    ----------
    input: Tensor
        A one dimensional uint8 tensor containing the raw bytes of the
        PNG or JPEG image.
    channels: int
        Number of output channels of the decoded image. JPEG and PNG images
        accept different values (see `decode_jpeg()` and `decode_png()`);
        0 keeps the original number of channels. Default: 0

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    decoded = torch.ops.image.decode_image(input, channels)
    return decoded


def read_image(path: str, channels: int = 0) -> torch.Tensor:
    """
    Loads a JPEG or PNG file from disk and decodes it into a 3 dimensional
    uint8 tensor with values between 0 and 255, optionally converting the
    number of color channels.

    Parameters
    ----------
    path: str
        Path of the JPEG or PNG image.
    channels: int
        Number of output channels of the decoded image. JPEG and PNG images
        accept different values (see `decode_jpeg()` and `decode_png()`);
        0 keeps the original number of channels. Default: 0

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    return decode_image(read_file(path), channels)