# image.py
import torch

import os
import os.path as osp
5
import importlib.machinery
6
7
8
9
10
11
12
13
14
15
16

# True once the compiled ``image`` extension has been located and loaded via
# ``torch.ops.load_library``; stays False when it is missing or fails to load.
# NOTE(review): the public functions below call ``torch.ops.image.*``
# unconditionally, so they only work when this flag is True.
_HAS_IMAGE_OPT = False

try:
    # The extension is searched for in the parent directory of this file.
    lib_dir = osp.join(osp.dirname(__file__), "..")

    # Restrict the search to native extension modules (platform suffixes such
    # as ``.so`` / ``.pyd``), ignoring pure-Python modules of the same name.
    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")
    # A spec without an origin has no file to hand to load_library.
    if ext_specs is not None and ext_specs.origin is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Deliberate best effort: a missing or unloadable extension must not make
    # this module fail to import; _HAS_IMAGE_OPT simply stays False.
    pass


26
def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])

    Raises:
        ValueError: if ``input`` is not a non-empty 1-dimensional tensor,
            or if its dtype is not ``torch.uint8``.
    """
    # isinstance is checked first so numel()/ndim are only read on tensors.
    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:  # type: ignore[attr-defined]
        raise ValueError("Expected a non empty 1-dimensional tensor.")

    if input.dtype != torch.uint8:
        raise ValueError("Expected a torch.uint8 tensor.")

    # Decoding itself is done by the compiled ``image`` extension.
    output = torch.ops.image.decode_png(input)
    return output


47
def read_png(path: str) -> torch.Tensor:
    """
    Reads a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        path (str): path of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])

    Raises:
        ValueError: if ``path`` is not an existing regular file, or if the
            file is empty.
    """
    if not os.path.isfile(path):
        raise ValueError("Expected a valid file path.")

    size = os.path.getsize(path)
    # An empty file has no bytes to decode; reject it before reading.
    if size == 0:
        raise ValueError("Expected a non empty file.")
    # Read the whole file as a flat uint8 tensor of its raw bytes.
    data = torch.from_file(path, dtype=torch.uint8, size=size)
    return decode_png(data)


68
def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])

    Raises:
        ValueError: if ``input`` is not a non-empty 1-dimensional tensor
            (including 0-d scalar tensors), or if its dtype is not
            ``torch.uint8``.
    """
    # numel() (not len()) so a 0-d tensor reaches the ndim check and raises
    # ValueError instead of len()'s TypeError; matches decode_png.
    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:  # type: ignore[attr-defined]
        raise ValueError("Expected a non empty 1-dimensional tensor.")

    if input.dtype != torch.uint8:
        raise ValueError("Expected a torch.uint8 tensor.")

    # Decoding itself is done by the compiled ``image`` extension.
    output = torch.ops.image.decode_jpeg(input)
    return output


88
def read_jpeg(path: str) -> torch.Tensor:
    """
    Reads a JPEG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        path (str): path of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])

    Raises:
        ValueError: if ``path`` is not an existing regular file, or if the
            file is empty.
    """
    if not os.path.isfile(path):
        raise ValueError("Expected a valid file path.")

    size = os.path.getsize(path)
    # An empty file has no bytes to decode; reject it before reading.
    if size == 0:
        raise ValueError("Expected a non empty file.")
    # Read the whole file as a flat uint8 tensor of its raw bytes.
    data = torch.from_file(path, dtype=torch.uint8, size=size)
    return decode_jpeg(data)


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encodes an image tensor as JPEG and returns the encoded file contents.

    The image is expected in CHW layout, or HW for grayscale images.

    Arguments:
        input (Tensor[channels, image_height, image_width]): image tensor
            with `c` channels, where `c` must be 1 or 3. Neither the dtype
            nor the shape is checked here.
        quality (int): JPEG quality, a number between 1 and 100 inclusive.
            Default: 75

    Returns:
        output (Tensor[1]): a one dimensional tensor holding the raw bytes
            of the JPEG file.

    Raises:
        ValueError: if ``quality`` is below 1 or above 100.
    """
    quality_out_of_range = quality < 1 or quality > 100
    if quality_out_of_range:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    # Encoding itself is done by the compiled ``image`` extension.
    return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Encodes an image tensor as JPEG and saves the result to ``filename``.

    The image is expected in CHW layout, or HW for grayscale images.

    Arguments:
        input (Tensor[channels, image_height, image_width]): image tensor
            with `c` channels, where `c` must be 1 or 3. Neither the dtype
            nor the shape is checked here.
        filename (str): Path to save the image.
        quality (int): JPEG quality, a number between 1 and 100 inclusive.
            Default: 75

    Raises:
        ValueError: if ``quality`` is below 1 or above 100; nothing is
            written in that case.
    """
    quality_out_of_range = quality < 1 or quality > 100
    if quality_out_of_range:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    # Encoding and writing are done by the compiled ``image`` extension.
    torch.ops.image.write_jpeg(input, filename, quality)