# image.py
import torch

import os
import os.path as osp
5
import importlib.machinery
6

7
8
from enum import Enum

9
10
11
# Set to True only once the compiled ``image`` extension has been registered
# with torch; it stays False when the extension is missing or fails to load.
_HAS_IMAGE_OPT = False

try:
    # The compiled extension is searched for one directory above this module.
    lib_dir = osp.abspath(osp.join(osp.dirname(__file__), ".."))

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")

    if os.name == 'nt':
        # Load the image extension using LoadLibraryExW
        import ctypes

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # 0x0001 is SEM_FAILCRITICALERRORS: suppress the critical-error dialog
        # while the DLL (and its dependencies) are being resolved.
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        try:
            kernel32.LoadLibraryW.restype = ctypes.c_void_p
            if with_load_library_flags:
                kernel32.LoadLibraryExW.restype = ctypes.c_void_p

            if ext_specs is not None:
                # 0x00001100 = LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
                #              LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
                res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += (f' Error loading "{ext_specs.origin}" or any of '
                                     'its dependencies.')
                    raise err
        finally:
            # Always restore the process-wide error mode, including when the
            # load above fails and the resulting OSError is swallowed below.
            kernel32.SetErrorMode(prev_error_mode)

    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Best effort: the module stays importable without the native extension.
    pass


51
class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use `ImageReadMode.UNCHANGED` for loading the image as-is,
    `ImageReadMode.GRAY` for converting to grayscale,
    `ImageReadMode.GRAY_ALPHA` for grayscale with transparency,
    `ImageReadMode.RGB` for RGB and `ImageReadMode.RGB_ALPHA` for
    RGB with transparency.
    """
    # The integer ``value`` of a member is what the decode functions in this
    # module forward to the underlying ops (``mode.value``).
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Args:
        path (str): the path to the file to be read

    Returns:
        data (Tensor)
    """
    data = torch.ops.image.read_file(path)
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of a uint8 tensor with one dimension to a
    file.

    Args:
        filename (str): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    # Delegate to the ``image`` extension op; nothing is returned.
    write_op = torch.ops.image.write_file
    write_op(filename, data)


95
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: `ImageReadMode.UNCHANGED`.
            See `ImageReadMode` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The op receives the integer value of the enum member, not the member.
    output = torch.ops.image.decode_png(input, mode.value)
    return output


116
117
118
119
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            `c` channels, where `c` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
            PNG file.
    """
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            `c` channels, where `c` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Encode in memory first, then write the encoded bytes to ``filename``.
    output = encode_png(input, compression_level)
    write_file(filename, output)
149
150


151
152
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
                device: str = 'cpu') -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: `ImageReadMode.UNCHANGED`.
            See `ImageReadMode` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Normalize to a torch.device so both strings and device objects work.
    device = torch.device(device)
    if device.type == 'cuda':
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value)
    return output


182
183
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            `c` channels, where `c` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional uint8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is smaller than 1 or larger than 100.
    """
    # Reject an out-of-range quality before handing anything to the op.
    out_of_range = quality < 1 or quality > 100
    if out_of_range:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of `c`
            channels, where `c` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # Encode in memory first, then write the encoded bytes to ``filename``.
    output = encode_jpeg(input, quality)
    write_file(filename, output)


220
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: `ImageReadMode.UNCHANGED`.
            See `ImageReadMode` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The op receives the integer value of the enum member, not the member.
    output = torch.ops.image.decode_image(input, mode.value)
    return output


243
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: `ImageReadMode.UNCHANGED`.
            See `ImageReadMode` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Load the raw file bytes, then let decode_image pick the right decoder.
    data = read_file(path)
    return decode_image(data, mode)