image.py 7.67 KB
Newer Older
1
2
from enum import Enum

3
4
import torch

5
from .._internally_replaced_utils import _get_extension_path
6
7


8
# Register the compiled "image" extension with torch so the
# ``torch.ops.image.*`` operators used by the functions below exist.
# Loading is best-effort: ImportError/OSError are deliberately swallowed so
# that importing this module does not fail when the extension is unavailable.
try:
    lib_path = _get_extension_path("image")
    torch.ops.load_library(lib_path)
except (ImportError, OSError):
    pass


15
class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.
    """

    # The integer ``.value`` of each member is what the decode functions in
    # this module forward to the native operators, so the numbers must not be
    # changed.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


Francisco Massa's avatar
Francisco Massa committed
33
34
35
36
37
def read_file(path: str) -> torch.Tensor:
    """
    Reads and outputs the bytes contents of a file as a uint8 Tensor
    with one dimension.

    Args:
        path (str): the path to the file to be read

    Returns:
        data (Tensor)
    """
    # Delegates entirely to the native operator registered by the image
    # extension.
    data = torch.ops.image.read_file(path)
    return data


Francisco Massa's avatar
Francisco Massa committed
48
49
50
51
52
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Writes the contents of a uint8 tensor with one dimension to a
    file.

    Args:
        filename (str): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    # Delegates entirely to the native operator registered by the image
    # extension.
    torch.ops.image.write_file(filename, data)


60
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The native operator takes the read mode as its integer value.
    output = torch.ops.image.decode_png(input, mode.value)
    return output


81
82
83
84
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    # NOTE(review): the docstring says int8, while read_file/write_file and the
    # decode functions in this module document uint8 -- confirm which dtype
    # the native operator expects and returns.
    # compression_level is not validated here; it is passed straight to the
    # native operator.
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Encode in memory first, then write the resulting byte buffer to disk.
    output = encode_png(input, compression_level)
    write_file(filename, output)
114
115


116
117
118
def decode_jpeg(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Normalise to a torch.device so the device type can be inspected.
    device = torch.device(device)
    if device.type == "cuda":
        # CUDA targets use the dedicated GPU operator, which also receives the
        # destination device.
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value)
    return output


148
149
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is outside the range [1, 100].
    """
    # NOTE(review): the docstring says int8, while read_file/write_file and the
    # decode functions in this module document uint8 -- confirm which dtype
    # the native operator expects and returns.
    # Reject an out-of-range quality before the native operator is called.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    output = torch.ops.image.encode_jpeg(input, quality)
    return output


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # Encode in memory first (encode_jpeg validates ``quality``), then write
    # the resulting byte buffer to disk.
    output = encode_jpeg(input, quality)
    write_file(filename, output)
Francisco Massa's avatar
Francisco Massa committed
183
184


185
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB Tensor.

    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Format detection happens inside the native operator; this wrapper only
    # forwards the bytes and the integer read mode.
    output = torch.ops.image.decode_image(input, mode.value)
    return output


208
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Load the raw file bytes, then let decode_image pick the right decoder.
    data = read_file(path)
    return decode_image(data, mode)