"""PNG reading/decoding helpers backed by an optional compiled ``image`` extension."""

import importlib
# ``import importlib`` alone does not guarantee that the ``machinery``
# submodule is bound as an attribute; import it explicitly so the lookups
# below cannot fail with an AttributeError (which the ``except`` clause
# further down would not catch).
import importlib.machinery
import os
import os.path as osp

import torch
from torch import nn, Tensor


# Set to True only when the compiled ``image`` extension was found and loaded.
_HAS_IMAGE_OPT = False

try:
    # Look for an extension module named ``image`` in the parent directory of
    # this file, using the platform's extension-module suffixes.
    lib_dir = osp.join(osp.dirname(__file__), "..")

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec("image")
    if ext_specs is not None:
        # Loading the shared library registers its operators under torch.ops.
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Best effort: the module stays importable without the extension;
    # _HAS_IMAGE_OPT simply remains False.
    pass


def decode_png(input):
    # type: (Tensor) -> Tensor
    """
    Decodes a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    The decoding itself is delegated to ``torch.ops.image.decode_png``, so the
    compiled ``image`` extension must have been loaded; this function does not
    check ``_HAS_IMAGE_OPT`` before calling the op.

    Arguments:
        input (Tensor): a non empty one dimensional uint8 tensor containing
            the raw bytes of the PNG image.

    Returns:
        output (Tensor[image_width, image_height, 3])
            NOTE(review): axis order is copied from the original docstring
            and is not verifiable from this file - confirm against the
            native implementation.

    Raises:
        ValueError: if ``input`` is not a tensor, is empty, is not
            1-dimensional, or does not have dtype ``torch.uint8``.
    """
    # The isinstance check comes first so the tensor attributes are only
    # accessed on real tensors.
    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:
        raise ValueError("Expected a non empty 1-dimensional tensor.")

    if input.dtype != torch.uint8:
        raise ValueError("Expected a torch.uint8 tensor.")

    return torch.ops.image.decode_png(input)


def read_png(path):
    # type: (str) -> Tensor
    """
    Reads a PNG image into a 3 dimensional RGB Tensor.
    The values of the output tensor are uint8 between 0 and 255.

    The file contents are loaded as a one dimensional uint8 tensor with
    ``torch.from_file`` and then passed to ``decode_png``.

    Arguments:
        path (str): path of the PNG image.

    Returns:
        output (Tensor[image_width, image_height, 3])
            NOTE(review): axis order is copied from the original docstring
            and is not verifiable from this file - confirm against the
            native implementation.

    Raises:
        ValueError: if ``path`` is not an existing regular file, or if the
            file is empty.
    """
    if not os.path.isfile(path):
        raise ValueError("Expected a valid file path.")

    size = os.path.getsize(path)
    if size == 0:
        raise ValueError("Expected a non empty file.")

    # One uint8 element per byte of the file.
    data = torch.from_file(path, dtype=torch.uint8, size=size)
    return decode_png(data)