Unverified Commit 7621a8ed authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add namespace on image C++ codebase (#3312)

* Moving jpegcommon inside cpu implementation

* Adding namespaces on image and moving private methods to anonymous.

* Fixing headers.

* Renaming public image methods to match the ones on python.

* Refactoring to remove the double ifs in common_jpeg.h
parent e95a3d22
#if JPEG_FOUND #include "common_jpeg.h"
#include "jpegcommon.h"
namespace vision {
namespace image {
namespace detail {
#if JPEG_FOUND
void torch_jpeg_error_exit(j_common_ptr cinfo) { void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce /* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */ * pointer */
...@@ -16,3 +20,7 @@ void torch_jpeg_error_exit(j_common_ptr cinfo) { ...@@ -16,3 +20,7 @@ void torch_jpeg_error_exit(j_common_ptr cinfo) {
longjmp(myerr->setjmp_buffer, 1); longjmp(myerr->setjmp_buffer, 1);
} }
#endif #endif
} // namespace detail
} // namespace image
} // namespace vision
#pragma once #pragma once
// clang-format off
#include <cstdio>
#include <cstddef>
// clang-format on
#if JPEG_FOUND #if JPEG_FOUND
#include <stdio.h>
#include <jpeglib.h> #include <jpeglib.h>
#include <setjmp.h> #include <setjmp.h>
namespace vision {
namespace image {
namespace detail {
static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr { struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */ struct jpeg_error_mgr pub; /* "public" fields */
...@@ -19,4 +20,8 @@ struct torch_jpeg_error_mgr { ...@@ -19,4 +20,8 @@ struct torch_jpeg_error_mgr {
using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*; using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*;
void torch_jpeg_error_exit(j_common_ptr cinfo); void torch_jpeg_error_exit(j_common_ptr cinfo);
} // namespace detail
} // namespace image
} // namespace vision
#endif #endif
#pragma once
#if PNG_FOUND
#include <png.h>
#include <setjmp.h>
#endif
#include "read_image_impl.h" #include "decode_image.h"
#include "readjpeg_impl.h" #include "decode_jpeg.h"
#include "readpng_impl.h" #include "decode_png.h"
namespace vision {
namespace image {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
...@@ -17,9 +20,9 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { ...@@ -17,9 +20,9 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) { if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, mode); return decode_jpeg(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) { } else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, mode); return decode_png(data, mode);
} else { } else {
TORCH_CHECK( TORCH_CHECK(
false, false,
...@@ -27,3 +30,6 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { ...@@ -27,3 +30,6 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
"are currently supported."); "are currently supported.");
} }
} }
} // namespace image
} // namespace vision
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include <torch/types.h> #include <torch/types.h>
#include "../image_read_mode.h" #include "../image_read_mode.h"
namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_image( C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data, const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "readjpeg_impl.h" #include "decode_jpeg.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
#if !JPEG_FOUND #if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK( TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support"); false, "decode_jpeg: torchvision not compiled with libjpeg support");
} }
#else #else
#include "../jpegcommon.h"
using namespace detail;
namespace {
struct torch_jpeg_mgr { struct torch_jpeg_mgr {
struct jpeg_source_mgr pub; struct jpeg_source_mgr pub;
...@@ -64,7 +71,9 @@ static void torch_jpeg_set_source_mgr( ...@@ -64,7 +71,9 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data; src->pub.next_input_byte = src->data;
} }
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) { } // namespace
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
...@@ -146,4 +155,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) { ...@@ -146,4 +155,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
return tensor.permute({2, 0, 1}); return tensor.permute({2, 0, 1});
} }
#endif // JPEG_FOUND #endif
} // namespace image
} // namespace vision
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include <torch/types.h> #include <torch/types.h>
#include "../image_read_mode.h" #include "../image_read_mode.h"
C10_EXPORT torch::Tensor decodePNG( namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data, const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "readpng_impl.h" #include "decode_png.h"
#include "common_png.h"
namespace vision {
namespace image {
#if !PNG_FOUND #if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support"); TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
} }
#else #else
#include <png.h>
#include <setjmp.h>
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
...@@ -160,4 +163,7 @@ torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { ...@@ -160,4 +163,7 @@ torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1}); return tensor.permute({2, 0, 1});
} }
#endif // PNG_FOUND #endif
} // namespace image
} // namespace vision
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include <torch/types.h> #include <torch/types.h>
#include "../image_read_mode.h" #include "../image_read_mode.h"
C10_EXPORT torch::Tensor decodeJPEG( namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data, const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "writejpeg_impl.h" #include "encode_jpeg.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
#if !JPEG_FOUND #if !JPEG_FOUND
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) { torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
TORCH_CHECK( TORCH_CHECK(
false, "encodeJPEG: torchvision not compiled with libjpeg support"); false, "encode_jpeg: torchvision not compiled with libjpeg support");
} }
#else #else
#include "../jpegcommon.h"
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) { using namespace detail;
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling // Define compression structures and error handling
struct jpeg_compress_struct cinfo; struct jpeg_compress_struct cinfo;
struct torch_jpeg_error_mgr jerr; struct torch_jpeg_error_mgr jerr;
...@@ -98,3 +104,6 @@ torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) { ...@@ -98,3 +104,6 @@ torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
return outTensor; return outTensor;
} }
#endif #endif
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
namespace vision {
namespace image {
C10_EXPORT torch::Tensor encode_jpeg(
const torch::Tensor& data,
int64_t quality);
} // namespace image
} // namespace vision
#include "writejpeg_impl.h" #include "encode_jpeg.h"
#include "common_png.h"
namespace vision {
namespace image {
#if !PNG_FOUND #if !PNG_FOUND
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) { torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(false, "encodePNG: torchvision not compiled with libpng support"); TORCH_CHECK(
false, "encode_png: torchvision not compiled with libpng support");
} }
#else #else
#include <png.h>
#include <setjmp.h> namespace {
struct torch_mem_encode { struct torch_mem_encode {
char* buffer; char* buffer;
...@@ -59,7 +65,9 @@ void torch_png_write_data( ...@@ -59,7 +65,9 @@ void torch_png_write_data(
p->size += length; p->size += length;
} }
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) { } // namespace
torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
// Define compression structures and error handling // Define compression structures and error handling
png_structp png_write; png_structp png_write;
png_infop info_ptr; png_infop info_ptr;
...@@ -171,3 +179,6 @@ torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) { ...@@ -171,3 +179,6 @@ torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
} }
#endif #endif
} // namespace image
} // namespace vision
...@@ -2,6 +2,12 @@ ...@@ -2,6 +2,12 @@
#include <torch/types.h> #include <torch/types.h>
C10_EXPORT torch::Tensor encodePNG( namespace vision {
namespace image {
C10_EXPORT torch::Tensor encode_png(
const torch::Tensor& data, const torch::Tensor& data,
int64_t compression_level); int64_t compression_level);
} // namespace image
} // namespace vision
#include "read_write_file_impl.h" #include "read_write_file.h"
#include <sys/stat.h>
#ifdef _WIN32 #ifdef _WIN32
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
#include <Windows.h> #include <Windows.h>
#endif
namespace vision {
namespace image {
#ifdef _WIN32
namespace {
std::wstring utf8_decode(const std::string& str) { std::wstring utf8_decode(const std::string& str) {
if (str.empty()) { if (str.empty()) {
return std::wstring(); return std::wstring();
...@@ -21,6 +29,7 @@ std::wstring utf8_decode(const std::string& str) { ...@@ -21,6 +29,7 @@ std::wstring utf8_decode(const std::string& str) {
size_needed); size_needed);
return wstrTo; return wstrTo;
} }
} // namespace
#endif #endif
torch::Tensor read_file(const std::string& filename) { torch::Tensor read_file(const std::string& filename) {
...@@ -90,3 +99,6 @@ void write_file(const std::string& filename, torch::Tensor& data) { ...@@ -90,3 +99,6 @@ void write_file(const std::string& filename, torch::Tensor& data) {
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile); fclose(outfile);
} }
} // namespace image
} // namespace vision
#pragma once #pragma once
#include <sys/stat.h>
#include <torch/types.h> #include <torch/types.h>
namespace vision {
namespace image {
C10_EXPORT torch::Tensor read_file(const std::string& filename); C10_EXPORT torch::Tensor read_file(const std::string& filename);
C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);
...@@ -11,11 +11,17 @@ PyMODINIT_FUNC PyInit_image(void) { ...@@ -11,11 +11,17 @@ PyMODINIT_FUNC PyInit_image(void) {
} }
#endif #endif
namespace vision {
namespace image {
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG) .op("image::decode_png", &decode_png)
.op("image::encode_png", &encodePNG) .op("image::encode_png", &encode_png)
.op("image::decode_jpeg", &decodeJPEG) .op("image::decode_jpeg", &decode_jpeg)
.op("image::encode_jpeg", &encodeJPEG) .op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file) .op("image::read_file", &read_file)
.op("image::write_file", &write_file) .op("image::write_file", &write_file)
.op("image::decode_image", &decode_image); .op("image::decode_image", &decode_image);
} // namespace image
} // namespace vision
#pragma once #pragma once
#include "cpu/read_image_impl.h" #include "cpu/decode_image.h"
#include "cpu/read_write_file_impl.h" #include "cpu/decode_jpeg.h"
#include "cpu/readjpeg_impl.h" #include "cpu/decode_png.h"
#include "cpu/readpng_impl.h" #include "cpu/encode_jpeg.h"
#include "cpu/writejpeg_impl.h" #include "cpu/encode_png.h"
#include "cpu/writepng_impl.h" #include "cpu/read_write_file.h"
#pragma once #pragma once
#include <stdint.h>
namespace vision {
namespace image {
/* Should be kept in-sync with Python ImageReadMode enum */ /* Should be kept in-sync with Python ImageReadMode enum */
using ImageReadMode = int64_t; using ImageReadMode = int64_t;
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
...@@ -7,3 +12,6 @@ const ImageReadMode IMAGE_READ_MODE_GRAY = 1; ...@@ -7,3 +12,6 @@ const ImageReadMode IMAGE_READ_MODE_GRAY = 1;
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3; const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;
} // namespace image
} // namespace vision
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment