import torch

import os
import os.path as osp
import importlib.machinery

_HAS_IMAGE_OPT = False

try:
    # Look for the compiled "image" extension module in the directory one
    # level above this file.
    lib_dir = osp.abspath(osp.join(osp.dirname(__file__), ".."))

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")

    if os.name == 'nt' and ext_specs is not None:
        # On Windows, pre-load the extension DLL ourselves so that its
        # dependencies are searched for in the DLL's own directory (and the
        # default DLL directories) before torch.ops loads it below.
        import ctypes

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        # The LOAD_LIBRARY_SEARCH_* flags of LoadLibraryExW are only
        # supported on systems that provide AddDllDirectory.
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # SEM_FAILCRITICALERRORS (0x0001): report a missing dependency via the
        # return value instead of popping up a system error dialog.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.LoadLibraryExW.restype = ctypes.c_void_p
                # LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000) |
                # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100)
                res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100)
            else:
                # Search flags unavailable: fall back to the plain loader.
                kernel32.LoadLibraryW.restype = ctypes.c_void_p
                res = kernel32.LoadLibraryW(ext_specs.origin)
            # With a c_void_p restype, a NULL module handle comes back as None.
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += (f' Error loading "{ext_specs.origin}" or one of '
                                 'its dependencies.')
                raise err
        finally:
            # Restore the process-wide error mode even when loading fails.
            kernel32.SetErrorMode(prev_error_mode)

    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Best effort: when the extension is missing or cannot be loaded, leave
    # _HAS_IMAGE_OPT as False rather than failing at import time.
    pass


def read_file(path: str) -> torch.Tensor:
    """
    Loads the raw bytes of a file into a uint8 Tensor with one dimension.

    Arguments:
        path (str): location of the file to load

    Returns:
        data (Tensor): the file contents, one element per byte
    """
    return torch.ops.image.read_file(path)


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Stores a one dimensional uint8 tensor on disk as the contents of a file.

    Arguments:
        filename (str): destination path of the file
        data (Tensor): the bytes to store in the output file
    """
    writer = torch.ops.image.write_file
    writer(filename, data)


def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes the bytes of a PNG image into a 3 dimensional RGB Tensor whose
    values are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor holding the raw
            bytes of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_png(input)


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        PNG file (suitable for passing to ``write_file``).
    """
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6
    """
    # Encode in memory first, then dump the resulting byte buffer to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)


def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes the bytes of a JPEG image into a 3 dimensional RGB Tensor whose
    values are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor holding the raw
            bytes of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_jpeg(input)


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    quality: int
        Quality of the resulting JPEG file, it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional uint8 tensor that contains the raw bytes of the
        JPEG file (suitable for passing to ``write_file``).

    Raises
    ------
    ValueError
        If ``quality`` is outside the range 1 to 100.
    """
    # Validate in Python so callers get a clear ValueError before the
    # native encoder is invoked.
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        uint8 image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    quality: int
        Quality of the resulting JPEG file, it must be a number
        between 1 and 100. Default: 75

    Raises
    ------
    ValueError
        If ``quality`` is outside the range 1 to 100 (raised by
        ``encode_jpeg`` before anything is written).
    """
    # Encode in memory first, then dump the resulting byte buffer to disk.
    output = encode_jpeg(input, quality)
    write_file(filename, output)


def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes the raw bytes of a JPEG or PNG image into a 3 dimensional RGB
    Tensor, choosing the appropriate decoder based on the image format.

    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor
        a one dimensional uint8 tensor holding the raw bytes of the
        PNG or JPEG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    decoder = torch.ops.image.decode_image
    return decoder(input)


def read_image(path: str) -> torch.Tensor:
    """
    Loads a JPEG or PNG image from disk into a 3 dimensional RGB Tensor with
    uint8 values between 0 and 255.

    Parameters
    ----------
    path: str
        location of the JPEG or PNG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    # Read the file's bytes, then let decode_image pick the right decoder.
    return decode_image(read_file(path))