image.py 5.51 KB
Newer Older
1
2
3
4
import torch

import os
import os.path as osp
5
import importlib.machinery
6
7
8
9
10
11
12
13
14
15
16

# Set to True only once the compiled "image" extension has been loaded.
_HAS_IMAGE_OPT = False

try:
    # Look for the compiled extension one directory above this file, using the
    # platform's native extension suffixes (e.g. ".so" / ".pyd").
    lib_dir = osp.join(osp.dirname(__file__), "..")

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")
    if ext_specs is not None:
        # Loading the library makes the torch.ops.image.* operators used by the
        # functions in this module available.
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Deliberate best effort: a missing or unloadable extension leaves
    # _HAS_IMAGE_OPT as False instead of failing at import time.
    pass


Francisco Massa's avatar
Francisco Massa committed
26
27
28
29
30
31
32
33
34
35
36
37
def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Arguments:
        path (str): the path to the file to be read

    Returns:
        data (Tensor): one dimensional uint8 tensor with the file's bytes.
    """
    data = torch.ops.image.read_file(path)
    return data


41
def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    output = torch.ops.image.decode_png(input)
    return output


57
58
59
60
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encode an image tensor in CHW layout into the contents of a PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        Image tensor of `c` channels, where `c` must be 1 or 3.
        NOTE(review): earlier docs described this as int8, while the decoders
        in this module produce uint8 - confirm the dtype the operator expects.
    compression_level: int
        Compression factor for the resulting file; it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional tensor that contains the raw bytes of the
        PNG file.
    """
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Save an image tensor in CHW layout (or HW in the case of grayscale
    images) to a PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        Image tensor of `c` channels, where `c` must be 1 or 3.
        NOTE(review): earlier docs described this as int8, while the decoders
        in this module produce uint8 - confirm the dtype the operator expects.
    filename: str
        Path to save the image.
    compression_level: int
        Compression factor for the resulting file; it must be a number
        between 0 and 9. Default: 6
    """
    png_writer = torch.ops.image.write_png
    png_writer(input, filename, compression_level)


98
def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    output = torch.ops.image.decode_jpeg(input)
    return output


112
113
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode an image tensor in CHW layout into the contents of a JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        Image tensor of `c` channels, where `c` must be 1 or 3.
        NOTE(review): earlier docs described this as int8, while the decoders
        in this module produce uint8 - confirm the dtype the operator expects.
    quality: int
        Quality of the resulting JPEG file; it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional tensor that contains the raw bytes of the
        JPEG file.

    Raises
    ------
    ValueError
        If `quality` is outside the range [1, 100].
    """
    # Reject out-of-range quality here, before reaching the native operator.
    out_of_range = quality < 1 or quality > 100
    if out_of_range:
        message = ('Image quality should be a positive number '
                   'between 1 and 100')
        raise ValueError(message)

    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Save an image tensor in CHW layout to a JPEG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        Image tensor of `c` channels, where `c` must be 1 or 3.
        NOTE(review): earlier docs described this as int8, while the decoders
        in this module produce uint8 - confirm the dtype the operator expects.
    filename: str
        Path to save the image.
    quality: int
        Quality of the resulting JPEG file; it must be a number
        between 1 and 100. Default: 75

    Raises
    ------
    ValueError
        If `quality` is outside the range [1, 100]; nothing is written then.
    """
    # Validate before touching the filesystem.
    out_of_range = quality < 1 or quality > 100
    if out_of_range:
        message = ('Image quality should be a positive number '
                   'between 1 and 100')
        raise ValueError(message)

    jpeg_writer = torch.ops.image.write_jpeg
    jpeg_writer(input, filename, quality)
Francisco Massa's avatar
Francisco Massa committed
158
159
160
161
162
163
164
165
166


def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a JPEG or PNG image into a 3 dimensional RGB Tensor.

    The image format is detected from the input bytes, and the appropriate
    decoder is applied. The values of the output tensor are uint8 between
    0 and 255.

    Parameters
    ----------
    input: Tensor
        A one dimensional uint8 tensor containing the raw bytes of the
        PNG or JPEG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    return torch.ops.image.decode_image(input)


def read_image(path: str) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    This is `read_file` followed by `decode_image`.

    Parameters
    ----------
    path: str
        path of the JPEG or PNG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    data = read_file(path)
    return decode_image(data)