Unverified Commit 766721b1 authored by Edgar Andrés Margffoy Tuay's avatar Edgar Andrés Margffoy Tuay Committed by GitHub
Browse files

PR: Add libpng and libjpeg-turbo requirement into conda recipe (#2301)



* Add libpng requirement into conda recipe

* Try to install libjpeg-turbo

* Add PNG reading capabilities

* Remove newline

* Add image extension to compilation instructions

* Include png functions as part of the main library

* Update CMakeLists

* Detect if building on conda-build

* Debug

* More debug messages

* Print globbed libreries

* Print globbed libreries

* Point to correct PNG path

* Remove libJPEG preventively

* Debug extension loading

* Link libpng explicitly

* Link with PNG

* Add PNG reading capabilities

* Add libpng requirement into conda recipe

* Try to install libjpeg-turbo

* Remove newline

* Add image extension to compilation instructions

* Include png functions as part of the main library

* Update CMakeLists

* Detect if building on conda-build

* Debug

* More debug messages

* Print globbed libreries

* Print globbed libreries

* Point to correct PNG path

* Remove libJPEG preventively

* Debug extension loading

* Link libpng explicitly

* Link with PNG

* Install libpng on conda-based wheel distributions

* Add -y flag

* Add -y flag to yum

* Locate LibPNG on windows conda

* Remove empty else

* Copy libpng16.so

* Copy dylib on Mac

* Improve check on Windows

* Try to install ninja using conda on windows

* Use libpng on Windows

* Package lib on windows wheel

* Point library to the correct place

* Include binaries as part of wheel

* Copy libpng.so on linux

* Look for png.h on Windows when using conda-build

* Do not skip png tests on Mac/Win

* Restore libjpeg-turbo

* Install jpeg-turbo on wheel distributions

* Install libjpeg-turbo from conda-forge on wheel distributions

* Do not pull av on conda-build

* Add pillow disclaimer

* Vendors libjpeg-turbo 2.0.4

* Merge JPEG work

* Remove submodules

* Regenerate circle config

* Fix style issues

* Fix C++ style issues

* More style corrections

* Add JPEG-turbo to linking libraries

* More style corrections

* More style corrections

* More style corrections

* Install libjpeg-turbo-devel

* Install libturbo-jpeg on typing pipeline

* Update Circle template

* Windows and Unix turbojpeg have the same linking name

* Install turbojpeg-devel instead of libjpeg-turbo

* Copy TurboJPEG binaries to wheel

* Move test image

* Move back test image

* Update JPEG test path

* Remove dot from extension

* Move image functions to extension

* Use stdout arg in subprocess

* Disable image extension if libpng or turbojpeg are not found

* Append libpng stdout

* Prevent list appending on lists

* Minor path correction

* Minor error correction

* Add linking flags

* Style issues correction

* Address minor review corrections

* Refactor library search

* Restore access index

* Fix JPEG tests

* Update libpng version in Travis

* Add -y flag

* Remove dot

* Update libpng using apt

* Check libpng version

* Change libturbojpeg binary

* Update import

* Change call

* Restore av in conda recipe

* Minor error correction

* Remove unused comment in travis.yml

* Update README

* Fix missing links

* Remove fixes for 16.04
Co-authored-by: default avatarRyad ZENINE <r.zenine@gmail.com>
parent 9d9e716f
...@@ -107,6 +107,8 @@ jobs: ...@@ -107,6 +107,8 @@ jobs:
- checkout - checkout
- run: - run:
command: | command: |
sudo apt-get update -y
sudo apt install -y libturbojpeg-dev
pip install --user --progress-bar off numpy mypy pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off --editable . pip install --user --progress-bar off --editable .
......
...@@ -107,6 +107,8 @@ jobs: ...@@ -107,6 +107,8 @@ jobs:
- checkout - checkout
- run: - run:
command: | command: |
sudo apt-get update -y
sudo apt install -y libturbojpeg-dev
pip install --user --progress-bar off numpy mypy pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off --editable . pip install --user --progress-bar off --editable .
......
...@@ -13,6 +13,7 @@ jobs: ...@@ -13,6 +13,7 @@ jobs:
before_install: before_install:
- sudo apt-get update - sudo apt-get update
- sudo apt-get install -y libpng-dev libturbojpeg-dev
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda - bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH" - export PATH="$HOME/miniconda/bin:$PATH"
......
...@@ -11,22 +11,28 @@ if(WITH_CUDA) ...@@ -11,22 +11,28 @@ if(WITH_CUDA)
endif() endif()
find_package(Python3 COMPONENTS Development) find_package(Python3 COMPONENTS Development)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
find_package(PNG REQUIRED)
file(GLOB HEADERS torchvision/csrc/*.h) file(GLOB HEADERS torchvision/csrc/*.h)
file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp torchvision/csrc/*.cpp) file(GLOB IMAGE_HEADERS torchvision/csrc/cpu/image/*.h)
file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.cpp)
file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp)
if(WITH_CUDA) if(WITH_CUDA)
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu) file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu)
endif() endif()
file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h) file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h)
file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)
add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES}) add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python3::Python) target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${PNG_LIBRARY} Python3::Python)
# target_link_libraries(${PROJECT_NAME} PRIVATE ${PNG_LIBRARY} Python3::Python)
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision) set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision)
target_include_directories(${PROJECT_NAME} INTERFACE target_include_directories(${PROJECT_NAME} INTERFACE
$<BUILD_INTERFACE:${HEADERS}> $<BUILD_INTERFACE:${HEADERS}:${PNG_INCLUDE_DIR}>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>) $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
include(GNUInstallDirs) include(GNUInstallDirs)
......
...@@ -78,6 +78,15 @@ Torchvision currently supports the following image backends: ...@@ -78,6 +78,15 @@ Torchvision currently supports the following image backends:
* `accimage`_ - if installed can be activated by calling :code:`torchvision.set_image_backend('accimage')` * `accimage`_ - if installed can be activated by calling :code:`torchvision.set_image_backend('accimage')`
* `libpng`_ - can be installed via conda :code:`conda install libpng` or any of the package managers for debian-based and RHEL-based Linux distributions.
* `libturbojpeg`_ - blazing speed, fast JPEG image loading. Can be installed from conda-forge :code:`conda install libjpeg-turbo -c conda-forge`.
**Notes:** ``libpng`` and ``libturbojpeg`` must be available at compilation time in order to be available. Also, most linux distributions distinguish between
``libturbojpeg`` and ``libjpeg-turbo``, where the former should be installed instead of the latter one.
.. _libpng : http://www.libpng.org/pub/png/libpng.html
.. _libturbojpeg: https://github.com/libjpeg-turbo/libjpeg-turbo
.. _Pillow : https://python-pillow.org/ .. _Pillow : https://python-pillow.org/
.. _Pillow-SIMD : https://github.com/uploadcare/pillow-simd .. _Pillow-SIMD : https://github.com/uploadcare/pillow-simd
.. _accimage: https://github.com/pytorch/accimage .. _accimage: https://github.com/pytorch/accimage
......
...@@ -10,6 +10,30 @@ setup_wheel_python ...@@ -10,6 +10,30 @@ setup_wheel_python
pip_install numpy pyyaml future ninja pip_install numpy pyyaml future ninja
setup_pip_pytorch_version setup_pip_pytorch_version
python setup.py clean python setup.py clean
# Copy binaries to be included in the wheel distribution
if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
python_exec="$(which python)"
bin_path=$(dirname $python_exec)
env_path=$(dirname $bin_path)
if [[ "$(uname)" == Darwin ]]; then
# Include LibPNG
cp "$env_path/lib/libpng16.dylib" torchvision
# Include TurboJPEG
cp "$env_path/lib/libturbojpeg.dylib" torchvision
else
# Include libPNG
cp "$bin_path/Library/lib/libpng.lib" torchvision
# Include TurboJPEG
cp "$bin_path/Library/lib/turbojpeg.lib" torchvision
fi
else
# Include LibPNG
cp "/usr/lib64/libpng.so" torchvision
# Include TurboJPEG
cp "/usr/lib64/libturbojpeg.so" torchvision
fi
if [[ "$OSTYPE" == "msys" ]]; then if [[ "$OSTYPE" == "msys" ]]; then
IS_WHEEL=1 "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel IS_WHEEL=1 "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel
else else
......
...@@ -170,7 +170,13 @@ setup_wheel_python() { ...@@ -170,7 +170,13 @@ setup_wheel_python() {
conda env remove -n "env$PYTHON_VERSION" || true conda env remove -n "env$PYTHON_VERSION" || true
conda create -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" conda create -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION"
conda activate "env$PYTHON_VERSION" conda activate "env$PYTHON_VERSION"
# Install libPNG from Anaconda (defaults)
conda install libpng -y
# Install libJPEG-turbo from conda-forge
conda install -y libjpeg-turbo -c conda-forge
else else
# Install native CentOS libPNG, libJPEG-turbo
yum install -y libpng-devel turbojpeg-devel
case "$PYTHON_VERSION" in case "$PYTHON_VERSION" in
2.7) 2.7)
if [[ -n "$UNICODE_ABI" ]]; then if [[ -n "$UNICODE_ABI" ]]; then
......
channel_sources:
- defaults,conda-forge
blas_impl: blas_impl:
- mkl # [x86_64] - mkl # [x86_64]
c_compiler: c_compiler:
......
...@@ -8,6 +8,8 @@ source: ...@@ -8,6 +8,8 @@ source:
requirements: requirements:
build: build:
- {{ compiler('c') }} # [win] - {{ compiler('c') }} # [win]
- libpng
- libjpeg-turbo
host: host:
- python - python
...@@ -18,6 +20,10 @@ requirements: ...@@ -18,6 +20,10 @@ requirements:
run: run:
- python - python
- libpng
- libjpeg-turbo
# Pillow introduces unwanted conflicts with libjpeg-turbo, as it depends on jpeg
# The fix depends on https://github.com/conda-forge/conda-forge.github.io/issues/673
- pillow >=4.1.1 - pillow >=4.1.1
- numpy >=1.11 - numpy >=1.11
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
......
...@@ -3,7 +3,7 @@ import io ...@@ -3,7 +3,7 @@ import io
import re import re
import sys import sys
from setuptools import setup, find_packages from setuptools import setup, find_packages
from pkg_resources import get_distribution, DistributionNotFound from pkg_resources import parse_version, get_distribution, DistributionNotFound
import subprocess import subprocess
import distutils.command.clean import distutils.command.clean
import distutils.spawn import distutils.spawn
...@@ -76,7 +76,76 @@ pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' ...@@ -76,7 +76,76 @@ pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver) requirements.append(pillow_req + pillow_ver)
def find_library(name, vision_include):
this_dir = os.path.dirname(os.path.abspath(__file__))
build_prefix = os.environ.get('BUILD_PREFIX', None)
is_conda_build = build_prefix is not None
library_found = False
conda_installed = False
lib_folder = None
include_folder = None
library_header = '{0}.h'.format(name)
print('Running build on conda-build: {0}'.format(is_conda_build))
if is_conda_build:
# Add conda headers/libraries
if os.name == 'nt':
build_prefix = os.path.join(build_prefix, 'Library')
include_folder = os.path.join(build_prefix, 'include')
lib_folder = os.path.join(build_prefix, 'lib')
library_header_path = os.path.join(
include_folder, library_header)
library_found = os.path.isfile(library_header_path)
conda_installed = library_found
else:
# Check if using Anaconda to produce wheels
conda = distutils.spawn.find_executable('conda')
is_conda = conda is not None
print('Running build on conda: {0}'.format(is_conda))
if is_conda:
python_executable = sys.executable
py_folder = os.path.dirname(python_executable)
if os.name == 'nt':
env_path = os.path.join(py_folder, 'Library')
else:
env_path = os.path.dirname(py_folder)
lib_folder = os.path.join(env_path, 'lib')
include_folder = os.path.join(env_path, 'include')
library_header_path = os.path.join(
include_folder, library_header)
library_found = os.path.isfile(library_header_path)
conda_installed = library_found
# Try to locate turbojpeg in Linux standard paths
if not library_found:
if sys.platform == 'linux':
library_found = os.path.exists('/usr/include/{0}'.format(
library_header))
library_found = library_found or os.path.exists(
'/usr/local/include/{0}'.format(library_header))
else:
# Lookup in TORCHVISION_INCLUDE or in the package file
package_path = [os.path.join(this_dir, 'torchvision')]
for folder in vision_include + package_path:
candidate_path = os.path.join(folder, library_header)
library_found = os.path.exists(candidate_path)
if library_found:
break
return library_found, conda_installed, include_folder, lib_folder
def get_extensions(): def get_extensions():
vision_include = os.environ.get('TORCHVISION_INCLUDE', None)
vision_library = os.environ.get('TORCHVISION_LIBRARY', None)
vision_include = (vision_include.split(os.pathsep)
if vision_include is not None else [])
vision_library = (vision_library.split(os.pathsep)
if vision_library is not None else [])
include_dirs = vision_include
library_dirs = vision_library
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
...@@ -149,13 +218,14 @@ def get_extensions(): ...@@ -149,13 +218,14 @@ def get_extensions():
sources = [os.path.join(extensions_dir, s) for s in sources] sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir] include_dirs += [extensions_dir]
ext_modules = [ ext_modules = [
extension( extension(
'torchvision._C', 'torchvision._C',
sources, sources,
include_dirs=include_dirs, include_dirs=include_dirs,
library_dirs=library_dirs,
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
) )
...@@ -171,6 +241,65 @@ def get_extensions(): ...@@ -171,6 +241,65 @@ def get_extensions():
) )
) )
# Image reading extension
image_macros = []
image_include = [extensions_dir]
image_library = []
image_link_flags = []
# Locating libPNG
libpng = distutils.spawn.find_executable('libpng-config')
png_found = libpng is not None
image_macros += [('PNG_FOUND', str(int(png_found)))]
print('PNG found: {0}'.format(png_found))
if png_found:
png_version = subprocess.run([libpng, '--version'],
stdout=subprocess.PIPE)
png_version = png_version.stdout.strip().decode('utf-8')
print('libpng version: {0}'.format(png_version))
png_version = parse_version(png_version)
if png_version >= parse_version("1.6.0"):
print('Building torchvision with PNG image support')
png_lib = subprocess.run([libpng, '--libdir'],
stdout=subprocess.PIPE)
png_include = subprocess.run([libpng, '--I_opts'],
stdout=subprocess.PIPE)
image_library += [png_lib.stdout.strip().decode('utf-8')]
image_include += [png_include.stdout.strip().decode('utf-8')]
image_link_flags.append('png' if os.name != 'nt' else 'libpng')
else:
print('libpng installed version is less than 1.6.0, '
'disabling PNG support')
png_found = False
# Locating libjpegturbo
turbojpeg_info = find_library('turbojpeg', vision_include)
(turbojpeg_found, conda_installed,
turbo_include_folder, turbo_lib_folder) = turbojpeg_info
image_macros += [('JPEG_FOUND', str(int(turbojpeg_found)))]
print('turboJPEG found: {0}'.format(turbojpeg_found))
if turbojpeg_found:
print('Building torchvision with JPEG image support')
image_link_flags.append('turbojpeg')
if conda_installed:
image_library += [turbo_lib_folder]
image_include += [turbo_include_folder]
image_path = os.path.join(extensions_dir, 'cpu', 'image')
image_src = glob.glob(os.path.join(image_path, '*.cpp'))
if png_found or turbojpeg_found:
ext_modules.append(extension(
'torchvision.image',
image_src,
include_dirs=include_dirs + [image_path] + image_include,
library_dirs=library_dirs + image_library,
define_macros=image_macros,
libraries=image_link_flags,
extra_compile_args=extra_compile_args
))
ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
has_ffmpeg = ffmpeg_exe is not None has_ffmpeg = ffmpeg_exe is not None
...@@ -243,7 +372,9 @@ setup( ...@@ -243,7 +372,9 @@ setup(
# Package info # Package info
packages=find_packages(exclude=('test',)), packages=find_packages(exclude=('test',)),
package_data={
package_name: ['*.lib', '*.dylib', '*.so']
},
zip_safe=False, zip_safe=False,
install_requires=requirements, install_requires=requirements,
extras_require={ extras_require={
......
import os
import unittest
import sys
import torch
import torchvision
from PIL import Image
from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg
import numpy as np
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
class ImageTester(unittest.TestCase):
def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, "jpg"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_ljpeg = read_jpeg(img_path)
norm = img_ljpeg.shape[0] * img_ljpeg.shape[1] * img_ljpeg.shape[2] * 255
err = torch.abs(img_ljpeg.flatten().float() - img_pil.flatten().float()).sum().float() / (norm)
self.assertLessEqual(err, 1e-2)
def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, "jpg"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
size = os.path.getsize(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
norm = img_ljpeg.shape[0] * img_ljpeg.shape[1] * img_ljpeg.shape[2] * 255
err = torch.abs(img_ljpeg.flatten().float() - img_pil.flatten().float()).sum().float() / (norm)
self.assertLessEqual(err, 1e-2)
with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
with self.assertRaisesRegex(RuntimeError, "Error while reading jpeg headers"):
decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_read_png(self):
for img_path in get_images(IMAGE_DIR, "png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_lpng = read_png(img_path)
self.assertEqual(img_lpng, img_pil)
def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, "png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
size = os.path.getsize(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertEqual(img_lpng, img_pil)
self.assertEqual(decode_png(torch.empty()), torch.empty())
self.assertEqual(decode_png(torch.randint(3, 5, (300,))), torch.empty())
if __name__ == '__main__':
unittest.main()
#include "image.h"
#include <ATen/ATen.h>
#include <Python.h>
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
#ifdef _WIN32
#if PY_MAJOR_VERSION < 3
PyMODINIT_FUNC init_image(void) {
// No need to do anything.
return NULL;
}
#else
PyMODINIT_FUNC PyInit_image(void) {
// No need to do anything.
return NULL;
}
#endif
#endif
static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::decode_jpeg", &decodeJPEG);
#pragma once
#include <torch/script.h>
#include <torch/torch.h>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "readjpeg_cpu.h"
#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>
#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data) {
AT_ERROR("decodeJPEG: torchvision not compiled with turboJPEG support");
}
#else
#include <turbojpeg.h>
torch::Tensor decodeJPEG(const torch::Tensor& data) {
tjhandle tjInstance = tjInitDecompress();
if (tjInstance == NULL) {
TORCH_CHECK(false, "libjpeg-turbo decompression initialization failed.");
}
auto datap = data.accessor<unsigned char, 1>().data();
int width, height;
if (tjDecompressHeader(tjInstance, datap, data.numel(), &width, &height) <
0) {
tjDestroy(tjInstance);
TORCH_CHECK(false, "Error while reading jpeg headers");
}
auto tensor =
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
int pixelFormat = TJPF_RGB;
auto ret = tjDecompress2(
tjInstance,
datap,
data.numel(),
ptr,
width,
0,
height,
pixelFormat,
NULL);
if (ret != 0) {
tjDestroy(tjInstance);
TORCH_CHECK(false, "decompressing JPEG image");
}
return tensor;
}
#endif // JPEG_FOUND
#pragma once
#include <torch/torch.h>
torch::Tensor decodeJPEG(const torch::Tensor& data);
#include "readpng_cpu.h"
#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>
#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data) {
AT_ERROR("decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
torch::Tensor decodePNG(const torch::Tensor& data) {
auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")
auto info_ptr = png_create_info_struct(png_ptr);
if (!info_ptr) {
png_destroy_read_struct(&png_ptr, nullptr, nullptr);
// Seems redundant with the if statement. done here to avoid leaking memory.
TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
}
auto datap = data.accessor<unsigned char, 1>().data();
if (setjmp(png_jmpbuf(png_ptr)) != 0) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Internal error.");
}
auto is_png = !png_sig_cmp(datap, 0, 8);
TORCH_CHECK(is_png, "Content is not png!")
struct Reader {
png_const_bytep ptr;
} reader;
reader.ptr = png_const_bytep(datap) + 8;
auto read_callback =
[](png_structp png_ptr, png_bytep output, png_size_t bytes) {
auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
std::copy(reader->ptr, reader->ptr + bytes, output);
reader->ptr += bytes;
};
png_set_sig_bytes(png_ptr, 8);
png_set_read_fn(png_ptr, &reader, read_callback);
png_read_info(png_ptr, info_ptr);
png_uint_32 width, height;
int bit_depth, color_type;
auto retval = png_get_IHDR(
png_ptr,
info_ptr,
&width,
&height,
&bit_depth,
&color_type,
nullptr,
nullptr,
nullptr);
if (retval != 1) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
if (color_type != PNG_COLOR_TYPE_RGB) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")
}
auto tensor =
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (decltype(height) i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor;
}
#endif // PNG_FOUND
#pragma once
#include <torch/torch.h>
#include <string>
torch::Tensor decodePNG(const torch::Tensor& data);
...@@ -30,5 +30,5 @@ __all__ = [ ...@@ -30,5 +30,5 @@ __all__ = [
"_read_video_clip_from_memory", "_read_video_clip_from_memory",
"_read_video_meta_data", "_read_video_meta_data",
"VideoMetaData", "VideoMetaData",
"Timebase", "Timebase"
] ]
import torch
from torch import nn, Tensor
import os
import os.path as osp
import importlib
_HAS_IMAGE_OPT = False
try:
lib_dir = osp.join(osp.dirname(__file__), "..")
loader_details = (
importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES
)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec("image")
if ext_specs is not None:
torch.ops.load_library(ext_specs.origin)
_HAS_IMAGE_OPT = True
except (ImportError, OSError):
pass
def decode_png(input):
# type: (Tensor) -> Tensor
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing
the raw bytes of the PNG image.
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or len(input) == 0:
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
raise ValueError("Expected a torch.uint8 tensor.")
output = torch.ops.image.decode_png(input)
return output
def read_png(path):
# type: (str) -> Tensor
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the PNG image.
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not os.path.isfile(path):
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return decode_png(data)
def decode_jpeg(input):
# type: (Tensor) -> Tensor
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing
the raw bytes of the JPEG image.
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
raise ValueError("Expected a torch.uint8 tensor.")
output = torch.ops.image.decode_jpeg(input)
return output
def read_jpeg(path):
# type: (str) -> Tensor
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG image.
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not os.path.isfile(path):
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return decode_jpeg(data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment