import torch

import os
import os.path as osp
import importlib.machinery

# Set to True only when the compiled ``image`` extension was found and loaded.
_HAS_IMAGE_OPT = False

try:
    # The extension is looked up one directory above this file.
    lib_dir = osp.join(osp.dirname(__file__), "..")

    # Restrict the finder to compiled extension modules.
    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")
    # A spec without an origin has no file to hand to load_library, so it is
    # treated the same as "extension not found".
    if ext_specs is not None and ext_specs.origin is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Deliberate best-effort: a missing or unloadable extension leaves
    # _HAS_IMAGE_OPT at False instead of failing the import of this module.
    pass


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Parameters
    ----------
    path: str
        the path to the file to be read

    Returns
    -------
    data: Tensor
        the result of the ``image.read_file`` custom op for ``path``
    """
    data = torch.ops.image.read_file(path)
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of a uint8 tensor with one dimension to a
    file.

    Parameters
    ----------
    filename: str
        the path to the file to be written
    data: Tensor
        the contents to be written to the output file
    """
    torch.ops.image.write_file(filename, data)


def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor[1]
        a one dimensional uint8 tensor containing the raw bytes of the
        PNG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    output = torch.ops.image.decode_png(input)
    return output


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        image tensor of `c` channels, where `c` must be 3 or 1.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional tensor that contains the raw bytes of the
        PNG file.
    """
    # NOTE(review): the original docstring called both tensors "int8", while
    # the decode/read helpers in this module describe image data and raw bytes
    # as uint8 -- presumably uint8 is meant here too; confirm against the op.
    # The range of compression_level is not checked here; it is passed
    # straight through to the op.
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6
    """
    # NOTE(review): the original docstring described the input as "int8";
    # the decode helpers in this module produce uint8 images -- presumably
    # uint8 is expected here as well; confirm against the op.
    torch.ops.image.write_png(input, filename, compression_level)


def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor[1]
        a one dimensional uint8 tensor containing the raw bytes of the
        JPEG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    output = torch.ops.image.decode_jpeg(input)
    return output


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        image tensor of `c` channels, where `c` must be 1 or 3.
    quality: int
        Quality of the resulting JPEG file, it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional tensor that contains the raw bytes of the
        JPEG file.

    Raises
    ------
    ValueError
        If ``quality`` is smaller than 1 or greater than 100.
    """
    # NOTE(review): the original docstring called both tensors "int8", while
    # the decode/read helpers in this module describe image data and raw bytes
    # as uint8 -- presumably uint8 is meant here too; confirm against the op.
    # Reject an out-of-range quality before the op is called.
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        image tensor of `c` channels, where `c` must be 1 or 3.
    filename: str
        Path to save the image.
    quality: int
        Quality of the resulting JPEG file, it must be a number
        between 1 and 100. Default: 75

    Raises
    ------
    ValueError
        If ``quality`` is smaller than 1 or greater than 100.
    """
    # NOTE(review): the original docstring described the input as "int8";
    # the decode helpers in this module produce uint8 images -- presumably
    # uint8 is expected here as well; confirm against the op.
    # Reject an out-of-range quality before anything is written.
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    torch.ops.image.write_jpeg(input, filename, quality)


def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    input: Tensor
        a one dimensional uint8 tensor containing the raw bytes of the
        PNG or JPEG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    # Format detection is not done here; the op receives the raw bytes.
    output = torch.ops.image.decode_image(input)
    return output


def read_image(path: str) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    path: str
        path of the JPEG or PNG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    # Two steps: load the raw file bytes, then decode them.
    data = read_file(path)
    return decode_image(data)