Unverified Commit 7621a8ed authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add namespace on image C++ codebase (#3312)

* Moving jpegcommon inside cpu implementation

* Adding namespaces on image and moving private methods to anonymous.

* Fixing headers.

* Renaming public image methods to match the ones on python.

* Refactoring to remove the double ifs in common_jpeg.h
parent e95a3d22
#if JPEG_FOUND
#include "jpegcommon.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
namespace detail {
#if JPEG_FOUND
void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
......@@ -16,3 +20,7 @@ void torch_jpeg_error_exit(j_common_ptr cinfo) {
longjmp(myerr->setjmp_buffer, 1);
}
#endif
} // namespace detail
} // namespace image
} // namespace vision
#pragma once
// clang-format off
#include <cstdio>
#include <cstddef>
// clang-format on
#if JPEG_FOUND
#include <stdio.h>
#include <jpeglib.h>
#include <setjmp.h>
namespace vision {
namespace image {
namespace detail {
static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
......@@ -19,4 +20,8 @@ struct torch_jpeg_error_mgr {
using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*;
void torch_jpeg_error_exit(j_common_ptr cinfo);
} // namespace detail
} // namespace image
} // namespace vision
#endif
#pragma once
#if PNG_FOUND
#include <png.h>
#include <setjmp.h>
#endif
#include "read_image_impl.h"
#include "decode_image.h"
#include "readjpeg_impl.h"
#include "readpng_impl.h"
#include "decode_jpeg.h"
#include "decode_png.h"
namespace vision {
namespace image {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
......@@ -17,9 +20,9 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, mode);
return decode_jpeg(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, mode);
return decode_png(data, mode);
} else {
TORCH_CHECK(
false,
......@@ -27,3 +30,6 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
"are currently supported.");
}
}
} // namespace image
} // namespace vision
......@@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"
namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "readjpeg_impl.h"
#include "decode_jpeg.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
false, "decode_jpeg: torchvision not compiled with libjpeg support");
}
#else
#include "../jpegcommon.h"
using namespace detail;
namespace {
struct torch_jpeg_mgr {
struct jpeg_source_mgr pub;
......@@ -64,7 +71,9 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
} // namespace
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -146,4 +155,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
return tensor.permute({2, 0, 1});
}
#endif // JPEG_FOUND
#endif
} // namespace image
} // namespace vision
......@@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"
C10_EXPORT torch::Tensor decodePNG(
namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "readpng_impl.h"
#include "decode_png.h"
#include "common_png.h"
namespace vision {
namespace image {
#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
#include <setjmp.h>
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -160,4 +163,7 @@ torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
}
#endif // PNG_FOUND
#endif
} // namespace image
} // namespace vision
......@@ -3,6 +3,12 @@
#include <torch/types.h>
#include "../image_read_mode.h"
C10_EXPORT torch::Tensor decodeJPEG(
namespace vision {
namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
} // namespace image
} // namespace vision
#include "writejpeg_impl.h"
#include "encode_jpeg.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
#if !JPEG_FOUND
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
TORCH_CHECK(
false, "encodeJPEG: torchvision not compiled with libjpeg support");
false, "encode_jpeg: torchvision not compiled with libjpeg support");
}
#else
#include "../jpegcommon.h"
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
using namespace detail;
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling
struct jpeg_compress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
......@@ -98,3 +104,6 @@ torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
return outTensor;
}
#endif
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
namespace vision {
namespace image {
C10_EXPORT torch::Tensor encode_jpeg(
const torch::Tensor& data,
int64_t quality);
} // namespace image
} // namespace vision
#include "writejpeg_impl.h"
#include "encode_jpeg.h"
#include "common_png.h"
namespace vision {
namespace image {
#if !PNG_FOUND
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(false, "encodePNG: torchvision not compiled with libpng support");
torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(
false, "encode_png: torchvision not compiled with libpng support");
}
#else
#include <png.h>
#include <setjmp.h>
namespace {
struct torch_mem_encode {
char* buffer;
......@@ -59,7 +65,9 @@ void torch_png_write_data(
p->size += length;
}
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
} // namespace
torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
// Define compression structures and error handling
png_structp png_write;
png_infop info_ptr;
......@@ -171,3 +179,6 @@ torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
}
#endif
} // namespace image
} // namespace vision
......@@ -2,6 +2,12 @@
#include <torch/types.h>
C10_EXPORT torch::Tensor encodePNG(
namespace vision {
namespace image {
C10_EXPORT torch::Tensor encode_png(
const torch::Tensor& data,
int64_t compression_level);
} // namespace image
} // namespace vision
#include "read_write_file_impl.h"
#include "read_write_file.h"
#include <sys/stat.h>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif
namespace vision {
namespace image {
#ifdef _WIN32
namespace {
std::wstring utf8_decode(const std::string& str) {
if (str.empty()) {
return std::wstring();
......@@ -21,6 +29,7 @@ std::wstring utf8_decode(const std::string& str) {
size_needed);
return wstrTo;
}
} // namespace
#endif
torch::Tensor read_file(const std::string& filename) {
......@@ -90,3 +99,6 @@ void write_file(const std::string& filename, torch::Tensor& data) {
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}
} // namespace image
} // namespace vision
#pragma once
#include <sys/stat.h>
#include <torch/types.h>
namespace vision {
namespace image {
C10_EXPORT torch::Tensor read_file(const std::string& filename);
C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);
......@@ -11,11 +11,17 @@ PyMODINIT_FUNC PyInit_image(void) {
}
#endif
namespace vision {
namespace image {
static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::encode_png", &encodePNG)
.op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG)
.op("image::decode_png", &decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg", &decode_jpeg)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
} // namespace image
} // namespace vision
#pragma once
#include "cpu/read_image_impl.h"
#include "cpu/read_write_file_impl.h"
#include "cpu/readjpeg_impl.h"
#include "cpu/readpng_impl.h"
#include "cpu/writejpeg_impl.h"
#include "cpu/writepng_impl.h"
#include "cpu/decode_image.h"
#include "cpu/decode_jpeg.h"
#include "cpu/decode_png.h"
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
#pragma once
#include <stdint.h>
namespace vision {
namespace image {
/* Should be kept in-sync with Python ImageReadMode enum */
using ImageReadMode = int64_t;
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
......@@ -7,3 +12,6 @@ const ImageReadMode IMAGE_READ_MODE_GRAY = 1;
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;
} // namespace image
} // namespace vision
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment