import importlib.machinery
import os
import os.path as osp

import torch

# Becomes True once the compiled ``image`` extension has been found and loaded
# via ``torch.ops.load_library``; the ``torch.ops.image`` operators used below
# are only available in that case.
_HAS_IMAGE_OPT = False

try:
    # The compiled extension is searched for one directory above this file.
    lib_dir = osp.join(osp.dirname(__file__), "..")

    # Restrict the search to native extension modules (files carrying one of
    # the interpreter's EXTENSION_SUFFIXES).
    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")
    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Deliberate best effort: if the extension is missing or fails to load,
    # the module still imports and _HAS_IMAGE_OPT stays False.
    pass


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    This delegates to the compiled ``torch.ops.image.read_file`` operator,
    so the ``image`` extension must have been loaded.

    Arguments:
        path (str): the path to the file to be read

    Returns:
        data (Tensor): one dimensional uint8 tensor with the file's bytes
    """
    data = torch.ops.image.read_file(path)
    return data


def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    output = torch.ops.image.decode_png(input)
    return output


def read_png(path: str) -> torch.Tensor:
    """
    Reads a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Equivalent to ``decode_png(read_file(path))``.

    Arguments:
        path (str): path of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    data = read_file(path)
    return decode_png(data)


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it
            must be a number between 0 and 9. Default: 6

    Returns:
        output (Tensor[1]): A one dimensional uint8 tensor that contains the raw
            bytes of the PNG file.

    Raises:
        ValueError: if ``compression_level`` is outside the range [0, 9].
    """
    # Reject out-of-range levels in Python before reaching the native op,
    # mirroring the ``quality`` check done by encode_jpeg.
    if compression_level < 0 or compression_level > 9:
        raise ValueError('Compression level should be a number '
                         'between 0 and 9')

    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it
            must be a number between 0 and 9. Default: 6

    Raises:
        ValueError: if ``compression_level`` is outside the range [0, 9].
    """
    # Validate before touching the filesystem, mirroring write_jpeg's
    # ``quality`` check.
    if compression_level < 0 or compression_level > 9:
        raise ValueError('Compression level should be a number '
                         'between 0 and 9')

    torch.ops.image.write_png(input, filename, compression_level)


def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    output = torch.ops.image.decode_jpeg(input)
    return output


def read_jpeg(path: str) -> torch.Tensor:
    """
    Reads a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Equivalent to ``decode_jpeg(read_file(path))``.

    Arguments:
        path (str): path of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    data = read_file(path)
    return decode_jpeg(data)


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional uint8 tensor that contains the raw
            bytes of the JPEG file.

    Raises:
        ValueError: if ``quality`` is outside the range [1, 100].
    """
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Raises:
        ValueError: if ``quality`` is outside the range [1, 100].
    """
    # Validate before the native op touches the filesystem.
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    torch.ops.image.write_jpeg(input, filename, quality)


def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Decode the raw bytes of a PNG or JPEG image into a 3 dimensional RGB
    tensor whose values are uint8 in the range 0 to 255.

    Format detection and dispatch to the matching decoder happen inside the
    compiled ``torch.ops.image.decode_image`` operator.

    Arguments:
        input (Tensor): one dimensional uint8 tensor holding the raw bytes
            of the PNG or JPEG image.
    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_image(input)


def read_image(path: str) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Equivalent to ``decode_image(read_file(path))``; the image format is
    detected from the file contents.

    Arguments:
        path (str): path of the JPEG or PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    data = read_file(path)
    return decode_image(data)