import warnings
from enum import Enum
from typing import Union

import torch

from .._internally_replaced_utils import _get_extension_path


# Load the compiled ``image`` extension library, which registers the
# ``torch.ops.image.*`` operators called by every function in this module.
# Loading is best-effort so the module can still be imported without the
# extension; a warning is emitted so that later "operator not found" errors
# can be traced back to this failure instead of being silently swallowed.
try:
    lib_path = _get_extension_path("image")
    torch.ops.load_library(lib_path)
except (ImportError, OSError) as e:
    warnings.warn(f"Failed to load image Python extension: {e}")


class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.
    """

    # The integer values are forwarded as-is (via ``.value``) to the
    # ``torch.ops.image.decode_*`` operators, so the native side presumably
    # relies on these exact codes — do not renumber.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Args:
        path (str): the path to the file to be read

    Returns:
        data (Tensor[1]): a one dimensional uint8 tensor holding the raw
            bytes of the file.
    """
    data = torch.ops.image.read_file(path)
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of a uint8 tensor with one dimension to a
    file.

    Args:
        filename (str): the path to the file to be written
        data (Tensor[1]): a one dimensional uint8 tensor with the contents
            to be written to the output file
    """
    torch.ops.image.write_file(filename, data)


def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The third argument is a flag that ``_read_png_16`` in this module sets
    # to True; presumably it requests 16-bit output — TODO confirm in the op.
    output = torch.ops.image.decode_png(input, mode.value, False)
    return output


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional uint8 tensor that contains the raw bytes of the
            PNG file.
    """
    # Range checking of ``compression_level`` is left to the native operator.
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Encode in memory first, then dump the resulting byte buffer to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)


def decode_jpeg(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Normalize so both strings (e.g. "cuda:0") and torch.device objects work.
    device = torch.device(device)
    if device.type == "cuda":
        # GPU path: the op receives the target device explicitly.
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value)
    return output


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional uint8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is not between 1 and 100 (inclusive).
    """
    # Validate in Python so callers get a ValueError before reaching the native op.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Raises:
        ValueError: propagated from ``encode_jpeg`` if ``quality`` is out of range.
    """
    # Encode in memory first, then dump the resulting byte buffer to disk.
    output = encode_jpeg(input, quality)
    write_file(filename, output)


def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Format detection happens inside the native op, not in Python.
    output = torch.ops.image.decode_image(input, mode.value)
    return output


def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 in [0, 255].

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Load the raw file bytes, then let decode_image pick the decoder.
    data = read_file(path)
    return decode_image(data, mode)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read the PNG file at ``path`` and decode it with the
    ``decode_png`` operator, passing ``True`` as its third argument (the
    public ``decode_png`` wrapper passes ``False``). Judging by the name,
    this flag presumably requests 16-bit output — TODO confirm in the op.
    """
    raw_bytes = read_file(path)
    mode_code = mode.value
    return torch.ops.image.decode_png(raw_bytes, mode_code, True)