"docs/zh_cn/tutorials/index.rst" did not exist on "dd74422e1455a6005c75ab73e08a11a8f4f22b92"
image.py 8.07 KB
Newer Older
1
2
import warnings
from enum import Enum

import torch

from .._internally_replaced_utils import _get_extension_path
6
7


8
# Register the native ``image`` operators with torch at import time.
# Loading is best-effort: the module must stay importable even when the
# compiled extension is missing, but the failure is reported so that later
# ``torch.ops.image.*`` errors can be traced back to their cause.
try:
    lib_path = _get_extension_path("image")
    torch.ops.load_library(lib_path)
except (ImportError, OSError) as load_error:
    warnings.warn(f"Failed to load image Python extension: {load_error}")


15
class ImageReadMode(Enum):
    """
    Support for various modes while reading images.

    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.

    The integer ``value`` of each member is what the decode functions in
    this module forward to the underlying operators.
    """

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4


Francisco Massa's avatar
Francisco Massa committed
33
34
35
36
37
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw bytes of a file into a one dimensional uint8 Tensor.

    Args:
        path (str): the path to the file to be read

    Returns:
        data (Tensor): the bytes of the file, as produced by the
            ``image.read_file`` operator.
    """
    return torch.ops.image.read_file(path)


Francisco Massa's avatar
Francisco Massa committed
48
49
50
51
52
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump a one dimensional uint8 tensor to disk as the contents of a file.

    Args:
        filename (str): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    # Delegates entirely to the native operator; nothing is returned.
    torch.ops.image.write_file(filename, data)


60
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor,
    optionally converting it to the requested format.

    Output values are uint8 in [0, 255]. The exception is 16-bits pngs,
    which are returned as int32 tensors in [0, 65535].

    .. warning::
        Should pytorch ever support the uint16 dtype natively, the dtype of the
        output for 16-bits pngs will be updated from int32 to uint16.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The operator receives the enum's integer value, not the enum itself.
    return torch.ops.image.decode_png(input, mode.value)


86
87
88
89
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Encodes an image tensor in CHW layout as PNG and returns the bytes of
    the resulting file in a buffer.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
            NOTE(review): presumably uint8, matching what the decoders in
            this module return - confirm.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    # No validation happens here; both arguments go straight to the operator.
    return torch.ops.image.encode_png(input, compression_level)


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Encodes an image tensor in CHW layout as PNG and saves it to a file.

    NOTE(review): earlier documentation also allowed an HW tensor for
    grayscale images; confirm that the encoder accepts 2-D input.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # Encode in memory first, then hand the byte buffer to the file writer.
    write_file(filename, encode_png(input, compression_level))
119
120


121
122
123
def decode_jpeg(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor,
    optionally converting it to the requested format.
    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image. This tensor must be on CPU,
            regardless of the ``device`` parameter.
        mode (ImageReadMode): the read mode used for optionally
            converting the image. Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Normalise whatever was passed into a torch.device before branching.
    target = torch.device(device)
    if target.type != "cuda":
        return torch.ops.image.decode_jpeg(input, mode.value)
    # CUDA targets use the dedicated operator, which also needs the device.
    return torch.ops.image.decode_jpeg_cuda(input, mode.value, target)


153
154
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
            NOTE(review): presumably uint8, matching what the decoders in
            this module return - confirm.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is outside the range [1, 100].
    """
    # Reject an out-of-range quality before touching the native operator.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Encodes an image tensor in CHW layout as JPEG and saves it to a file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # Encode in memory first, then hand the byte buffer to the file writer.
    write_file(filename, encode_jpeg(input, quality))
Francisco Massa's avatar
Francisco Massa committed
188
189


190
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Detects whether an image is a JPEG or PNG and performs the appropriate
    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

    Optionally converts the image to the desired format.
    Output values are uint8 in [0, 255]. The exception is 16-bits pngs,
    which are returned as int32 tensors in [0, 65535].

    Args:
        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # The operator receives the enum's integer value, not the enum itself.
    return torch.ops.image.decode_image(input, mode.value)


214
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a JPEG or PNG image from disk into a 3 dimensional RGB or
    grayscale Tensor, optionally converting it to the requested format.

    Output values are uint8 in [0, 255]. The exception is 16-bits pngs,
    which are returned as int32 tensors in [0, 65535].

    Args:
        path (str): path of the JPEG or PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.
            Default: ``ImageReadMode.UNCHANGED``.
            See ``ImageReadMode`` class for more information on various
            available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Read the raw bytes, then let decode_image pick the right decoder.
    return decode_image(read_file(path), mode)