Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavcodec/bsf.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
}
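// Thin FFmpeg-based demuxer: it opens the input with libavformat, picks the
// best video stream, and hands out raw (Annex-B style) video packets that can
// be fed to a hardware decoder such as NVDEC.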
class Demuxer {
private:
AVFormatContext* fmtCtx = NULL;
AVBSFContext* bsfCtx = NULL;
AVPacket pkt, pktFiltered;
AVCodecID eVideoCodec;
uint8_t* dataWithHeader = NULL;
bool bMp4H264, bMp4HEVC, bMp4MPEG4;
unsigned int frameCount = 0;
int iVideoStream;
double timeBase = 0.0;
public:
Demuxer(const char* filePath) {
avformat_network_init();
TORCH_CHECK(
0 <= avformat_open_input(&fmtCtx, filePath, NULL, NULL),
"avformat_open_input() failed at line ",
__LINE__,
" in demuxer.h\n");
if (!fmtCtx) {
TORCH_CHECK(
false,
"Encountered NULL AVFormatContext at line ",
__LINE__,
" in demuxer.h\n");
}
TORCH_CHECK(
0 <= avformat_find_stream_info(fmtCtx, NULL),
"avformat_find_stream_info() failed at line ",
__LINE__,
" in demuxer.h\n");
iVideoStream =
av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0);
if (iVideoStream < 0) {
TORCH_CHECK(
false,
"av_find_best_stream() failed at line ",
__LINE__,
" in demuxer.h\n");
}
eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id;
AVRational rTimeBase = fmtCtx->streams[iVideoStream]->time_base;
timeBase = av_q2d(rTimeBase);
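// MP4/MOV, FLV and Matroska store H.264/HEVC in length-prefixed (AVCC/hvcC)
// form; the *_mp4toannexb bitstream filters set up below convert those packets
// to Annex-B start-code form, which the downstream decoder expects.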
bMp4H264 = eVideoCodec == AV_CODEC_ID_H264 &&
(!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") ||
!strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") ||
!strcmp(fmtCtx->iformat->long_name, "Matroska / WebM"));
bMp4HEVC = eVideoCodec == AV_CODEC_ID_HEVC &&
(!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") ||
!strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") ||
!strcmp(fmtCtx->iformat->long_name, "Matroska / WebM"));
bMp4MPEG4 = eVideoCodec == AV_CODEC_ID_MPEG4 &&
(!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") ||
!strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") ||
!strcmp(fmtCtx->iformat->long_name, "Matroska / WebM"));
av_init_packet(&pkt);
pkt.data = NULL;
pkt.size = 0;
av_init_packet(&pktFiltered);
pktFiltered.data = NULL;
pktFiltered.size = 0;
if (bMp4H264) {
const AVBitStreamFilter* bsf = av_bsf_get_by_name("h264_mp4toannexb");
if (!bsf) {
TORCH_CHECK(
false,
"av_bsf_get_by_name() failed at line ",
__LINE__,
" in demuxer.h\n");
}
TORCH_CHECK(
0 <= av_bsf_alloc(bsf, &bsfCtx),
"av_bsf_alloc() failed at line ",
__LINE__,
" in demuxer.h\n");
avcodec_parameters_copy(
bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar);
TORCH_CHECK(
0 <= av_bsf_init(bsfCtx),
"av_bsf_init() failed at line ",
__LINE__,
" in demuxer.h\n");
}
if (bMp4HEVC) {
const AVBitStreamFilter* bsf = av_bsf_get_by_name("hevc_mp4toannexb");
if (!bsf) {
TORCH_CHECK(
false,
"av_bsf_get_by_name() failed at line ",
__LINE__,
" in demuxer.h\n");
}
TORCH_CHECK(
0 <= av_bsf_alloc(bsf, &bsfCtx),
"av_bsf_alloc() failed at line ",
__LINE__,
" in demuxer.h\n");
avcodec_parameters_copy(
bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar);
TORCH_CHECK(
0 <= av_bsf_init(bsfCtx),
"av_bsf_init() failed at line ",
__LINE__,
" in demuxer.h\n");
}
}
~Demuxer() {
if (!fmtCtx) {
return;
}
if (pkt.data) {
av_packet_unref(&pkt);
}
if (pktFiltered.data) {
av_packet_unref(&pktFiltered);
}
if (bsfCtx) {
av_bsf_free(&bsfCtx);
}
avformat_close_input(&fmtCtx);
if (dataWithHeader) {
av_free(dataWithHeader);
}
}
AVCodecID get_video_codec() {
return eVideoCodec;
}
double get_duration() const {
return (double)fmtCtx->duration / AV_TIME_BASE;
}
double get_fps() const {
return av_q2d(fmtCtx->streams[iVideoStream]->r_frame_rate);
}
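// Demux the next packet of the selected video stream. On success, *video
// points at an internal buffer (valid until the next demux() call) and
// *videoBytes holds its size; returns false when no packet is left.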
bool demux(uint8_t** video, unsigned long* videoBytes) {
if (!fmtCtx) {
return false;
}
*videoBytes = 0;
if (pkt.data) {
av_packet_unref(&pkt);
}
int e = 0;
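// Read packets until one from the selected video stream is found.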
while ((e = av_read_frame(fmtCtx, &pkt)) >= 0 &&
pkt.stream_index != iVideoStream) {
av_packet_unref(&pkt);
}
if (e < 0) {
return false;
}
if (bMp4H264 || bMp4HEVC) {
if (pktFiltered.data) {
av_packet_unref(&pktFiltered);
}
TORCH_CHECK(
0 <= av_bsf_send_packet(bsfCtx, &pkt),
"av_bsf_send_packet() failed at line ",
__LINE__,
" in demuxer.h\n");
TORCH_CHECK(
0 <= av_bsf_receive_packet(bsfCtx, &pktFiltered),
"av_bsf_receive_packet() failed at line ",
__LINE__,
" in demuxer.h\n");
*video = pktFiltered.data;
*videoBytes = pktFiltered.size;
} else {
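// For MPEG-4 elementary streams, the first packet is prepended with the
// codec extradata (sequence headers) so the decoder can parse the stream;
// the first 3 bytes of the original packet are skipped, presumably because
// they are already covered by the prepended header.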
if (bMp4MPEG4 && (frameCount == 0)) {
int extraDataSize =
fmtCtx->streams[iVideoStream]->codecpar->extradata_size;
if (extraDataSize > 0) {
dataWithHeader = (uint8_t*)av_malloc(
extraDataSize + pkt.size - 3 * sizeof(uint8_t));
if (!dataWithHeader) {
TORCH_CHECK(
false,
"av_malloc() failed at line ",
__LINE__,
" in demuxer.h\n");
}
memcpy(
dataWithHeader,
fmtCtx->streams[iVideoStream]->codecpar->extradata,
extraDataSize);
memcpy(
dataWithHeader + extraDataSize,
pkt.data + 3,
pkt.size - 3 * sizeof(uint8_t));
*video = dataWithHeader;
*videoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t);
}
} else {
*video = pkt.data;
*videoBytes = pkt.size;
}
}
frameCount++;
return true;
}
void seek(double timestamp, int flag) {
int64_t time = timestamp * AV_TIME_BASE;
TORCH_CHECK(
0 <= av_seek_frame(fmtCtx, -1, time, flag),
"av_seek_frame() failed at line ",
__LINE__,
" in demuxer.h\n");
}
};
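// A minimal usage sketch (hypothetical file name):
//   Demuxer demuxer("input.mp4");
//   uint8_t* data = nullptr;
//   unsigned long size = 0;
//   while (demuxer.demux(&data, &size) && size > 0) {
//     // feed (data, size) to the decoder
//   }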
inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) {
switch (id) {
case AV_CODEC_ID_MPEG1VIDEO:
return cudaVideoCodec_MPEG1;
case AV_CODEC_ID_MPEG2VIDEO:
return cudaVideoCodec_MPEG2;
case AV_CODEC_ID_MPEG4:
return cudaVideoCodec_MPEG4;
case AV_CODEC_ID_WMV3:
case AV_CODEC_ID_VC1:
return cudaVideoCodec_VC1;
case AV_CODEC_ID_H264:
return cudaVideoCodec_H264;
case AV_CODEC_ID_HEVC:
return cudaVideoCodec_HEVC;
case AV_CODEC_ID_VP8:
return cudaVideoCodec_VP8;
case AV_CODEC_ID_VP9:
return cudaVideoCodec_VP9;
case AV_CODEC_ID_MJPEG:
return cudaVideoCodec_JPEG;
case AV_CODEC_ID_AV1:
return cudaVideoCodec_AV1;
default:
return cudaVideoCodec_NumCodecs;
}
}
#include "gpu_decoder.h"
#include <c10/cuda/CUDAGuard.h>
/* Set cuda device, create cuda context and initialise the demuxer and decoder.
*/
GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
: demuxer(src_file.c_str()) {
at::cuda::CUDAGuard device_guard(dev);
device = device_guard.current_device().index();
check_for_cuda_errors(
cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
initialised = true;
}
GPUDecoder::~GPUDecoder() {
at::cuda::CUDAGuard device_guard(device);
decoder.release();
if (initialised) {
check_for_cuda_errors(
cuDevicePrimaryCtxRelease(device), __LINE__, __FILE__);
}
}
/* Fetch a decoded frame tensor after demuxing and decoding.
*/
torch::Tensor GPUDecoder::decode() {
torch::Tensor frameTensor;
unsigned long videoBytes = 0;
uint8_t* video = nullptr;
at::cuda::CUDAGuard device_guard(device);
torch::Tensor frame;
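// Keep feeding demuxed packets to the decoder until it emits a frame;
// an empty frame together with videoBytes == 0 signals end of stream.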
do {
demuxer.demux(&video, &videoBytes);
decoder.decode(video, videoBytes);
frame = decoder.fetch_frame();
} while (frame.numel() == 0 && videoBytes > 0);
return frame;
}
/* Seek to a passed timestamp. The second argument controls whether to seek to a
* keyframe.
*/
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
int flag = keyframes_only ? 0 : AVSEEK_FLAG_ANY;
demuxer.seek(timestamp, flag);
}
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
get_metadata() const {
c10::Dict<std::string, c10::Dict<std::string, double>> metadata;
c10::Dict<std::string, double> video_metadata;
video_metadata.insert("duration", demuxer.get_duration());
video_metadata.insert("fps", demuxer.get_fps());
metadata.insert("video", video_metadata);
return metadata;
}
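// Register GPUDecoder as a TorchScript custom class; once the extension
// library is loaded it is available from Python as
// torch.classes.torchvision.GPUDecoder.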
TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, torch::Device>())
.def("seek", &GPUDecoder::seek)
.def("get_metadata", &GPUDecoder::get_metadata)
.def("next", &GPUDecoder::decode);
}
#include <torch/custom_class.h>
#include <torch/torch.h>
#include "decoder.h"
#include "demuxer.h"
class GPUDecoder : public torch::CustomClassHolder {
public:
GPUDecoder(std::string, torch::Device);
~GPUDecoder();
torch::Tensor decode();
void seek(double, bool);
c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const;
private:
Demuxer demuxer;
CUcontext ctx;
Decoder decoder;
int64_t device;
bool initialised = false;
};
......@@ -61,7 +61,7 @@ DecoderInCallback MemoryBuffer::getCallback(
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - supported
// seek capability, yes - supported
return 0;
}
return object.seek(size, whence);
......
#include "stream.h"
#include <c10/util/Logging.h>
#include <stdio.h>
#include <string.h>
#include "util.h"
namespace ffmpeg {
......@@ -24,11 +26,16 @@ Stream::~Stream() {
}
}
// look up the proper CODEC querying the function
AVCodec* Stream::findCodec(AVCodecParameters* params) {
return avcodec_find_decoder(params->codec_id);
return (AVCodec*)avcodec_find_decoder(params->codec_id);
}
int Stream::openCodec(std::vector<DecoderMetadata>* metadata) {
// Allocate memory for the AVCodecContext, which will hold the context for the
// decode/encode process. Then fill this codec context with the CODEC parameters
// defined in the stream parameters. Open the codec, and allocate the global
// frame defined in the header file.
int Stream::openCodec(std::vector<DecoderMetadata>* metadata, int num_threads) {
AVStream* steam = inputCtx_->streams[format_.stream];
AVCodec* codec = findCodec(steam->codecpar);
......@@ -44,6 +51,21 @@ int Stream::openCodec(std::vector<DecoderMetadata>* metadata) {
<< ", avcodec_alloc_context3 failed";
return AVERROR(ENOMEM);
}
// multithreading heuristics
// if user defined, clamp to the conservative max_threads limit
if (num_threads > max_threads) {
num_threads = max_threads;
}
if (num_threads > 0) {
// if user defined, respect that
// note that default thread_type will be used
codecCtx_->thread_count = num_threads;
} else {
// otherwise set sensible defaults
codecCtx_->thread_count = 8;
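// FF_THREAD_SLICE parallelizes decoding within a single frame and, unlike
// FF_THREAD_FRAME, does not add extra frames of decoding latency.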
codecCtx_->thread_type = FF_THREAD_SLICE;
}
int ret;
// Copy codec parameters from input stream to output codec context
......@@ -93,6 +115,9 @@ int Stream::openCodec(std::vector<DecoderMetadata>* metadata) {
return ret;
}
// send the raw data packet (compressed frame) to the decoder, through the codec
// context and receive the raw data frame (uncompressed frame) from the
// decoder, through the same codec context
int Stream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
int consumed = 0;
int result = avcodec_send_packet(codecCtx_, packet);
......@@ -134,6 +159,9 @@ int Stream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
return consumed;
}
// General decoding function:
// given the packet, analyse the metadata, and write the
// metadata and the buffer to the DecoderOutputMessage.
int Stream::decodePacket(
const AVPacket* packet,
DecoderOutputMessage* out,
......@@ -167,6 +195,9 @@ int Stream::flush(DecoderOutputMessage* out, bool headerOnly) {
return 1;
}
// Sets the header and payload via Stream::setHeader and copyFrameBytes,
// functions that are defined in the Stream subclasses (VideoStream,
// AudioStream, ...)
int Stream::getMessage(DecoderOutputMessage* out, bool flush, bool headerOnly) {
if (flush) {
// only flush of audio frames makes sense
......
......@@ -20,7 +20,9 @@ class Stream {
virtual ~Stream();
// returns 0 - on success or negative error
int openCodec(std::vector<DecoderMetadata>* metadata);
// num_threads sets up the codec context for multithreading if needed
// default is set to single thread in order to not break BC
int openCodec(std::vector<DecoderMetadata>* metadata, int num_threads = 1);
// returns 1 - if packet got consumed, 0 - if it's not, and < 0 on error
int decodePacket(
const AVPacket* packet,
......@@ -69,6 +71,10 @@ class Stream {
// estimated next frame pts for flushing the last frame
int64_t nextPts_{0};
double fps_{30.};
// this is a dumb conservative limit; ideally we'd use
// int max_threads = at::get_num_threads(); but this would cause
// fb sync to fail as it would add a dependency on ATen to the decoder API
const int max_threads = 12;
};
} // namespace ffmpeg
......@@ -43,21 +43,34 @@ int SubtitleStream::initFormat() {
int SubtitleStream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
// clean-up
releaseSubtitle();
// FIXME: should this even be created?
AVPacket* avPacket;
avPacket = av_packet_alloc();
if (avPacket == nullptr) {
LOG(ERROR)
<< "decoder as not able to allocate the subtitle-specific packet.";
// alternative to ENOMEM
return AVERROR_BUFFER_TOO_SMALL;
}
avPacket->data = nullptr;
avPacket->size = 0;
// check flush packet
AVPacket avPacket;
av_init_packet(&avPacket);
avPacket.data = nullptr;
avPacket.size = 0;
auto pkt = packet ? *packet : avPacket;
auto pkt = packet ? packet : avPacket;
int gotFramePtr = 0;
int result = avcodec_decode_subtitle2(codecCtx_, &sub_, &gotFramePtr, &pkt);
// is there a better way than casting away const?
int result =
avcodec_decode_subtitle2(codecCtx_, &sub_, &gotFramePtr, (AVPacket*)pkt);
if (result < 0) {
LOG(ERROR) << "avcodec_decode_subtitle2 failed, err: "
<< Util::generateErrorDesc(result);
// free the packet we've created
av_packet_free(&avPacket);
return result;
} else if (result == 0) {
result = pkt.size; // discard the rest of the packet
result = pkt->size; // discard the rest of the packet
}
sub_.release = gotFramePtr;
......@@ -66,9 +79,10 @@ int SubtitleStream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
// set proper pts in us
if (gotFramePtr) {
sub_.pts = av_rescale_q(
pkt.pts, inputCtx_->streams[format_.stream]->time_base, timeBaseQ);
pkt->pts, inputCtx_->streams[format_.stream]->time_base, timeBaseQ);
}
av_packet_free(&avPacket);
return result;
}
......
......@@ -19,17 +19,17 @@ void SyncDecoder::AVByteStorage::ensure(size_t n) {
}
uint8_t* SyncDecoder::AVByteStorage::writableTail() {
CHECK_LE(offset_ + length_, capacity_);
TORCH_CHECK_LE(offset_ + length_, capacity_);
return buffer_ + offset_ + length_;
}
void SyncDecoder::AVByteStorage::append(size_t n) {
CHECK_LE(n, tail());
TORCH_CHECK_LE(n, tail());
length_ += n;
}
void SyncDecoder::AVByteStorage::trim(size_t n) {
CHECK_LE(n, length_);
TORCH_CHECK_LE(n, length_);
offset_ += n;
length_ -= n;
}
......@@ -43,7 +43,7 @@ size_t SyncDecoder::AVByteStorage::length() const {
}
size_t SyncDecoder::AVByteStorage::tail() const {
CHECK_LE(offset_ + length_, capacity_);
TORCH_CHECK_LE(offset_ + length_, capacity_);
return capacity_ - offset_ - length_;
}
......
......@@ -50,7 +50,8 @@ void gotFilesStats(std::vector<VideoFileStats>& stats) {
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
size_t s = fread(buffer.data(), 1, buffer.size(), f);
TORCH_CHECK_EQ(buffer.size(), s);
fclose(f);
for (size_t i = 0; i < rounds; ++i) {
......@@ -66,7 +67,7 @@ void gotFilesStats(std::vector<VideoFileStats>& stats) {
avgProvUs +=
std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count();
CHECK_EQ(metadata.size(), 1);
TORCH_CHECK_EQ(metadata.size(), 1);
item.num = metadata[0].num;
item.den = metadata[0].den;
item.fps = metadata[0].fps;
......@@ -90,7 +91,8 @@ size_t measurePerformanceUs(
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
size_t s = fread(buffer.data(), 1, buffer.size(), f);
TORCH_CHECK_EQ(buffer.size(), s);
fclose(f);
for (size_t i = 0; i < rounds; ++i) {
......@@ -324,7 +326,8 @@ TEST(SyncDecoder, TestMemoryBuffer) {
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
size_t s = fread(buffer.data(), 1, buffer.size(), f);
TORCH_CHECK_EQ(buffer.size(), s);
fclose(f);
CHECK(decoder.init(
params,
......@@ -349,7 +352,8 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithFullRead) {
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
size_t s = fread(buffer.data(), 1, buffer.size(), f);
TORCH_CHECK_EQ(buffer.size(), s);
fclose(f);
params.maxSeekableBytes = buffer.size() + 1;
......@@ -364,7 +368,7 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithFullRead) {
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - no
// seek capability, yes - no
return -1;
}
return object.seek(size, whence);
......@@ -388,7 +392,8 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithPartialRead) {
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
size_t s = fread(buffer.data(), 1, buffer.size(), f);
TORCH_CHECK_EQ(buffer.size(), s);
fclose(f);
params.maxSeekableBytes = buffer.size() / 2;
......@@ -403,7 +408,7 @@ TEST(SyncDecoder, TestMemoryBufferNoSeekableWithPartialRead) {
}
// seek mode
if (!timeoutMs) {
// seek capabilty, yes - no
// seek capability, yes - no
return -1;
}
return object.seek(size, whence);
......
......@@ -265,7 +265,7 @@ std::string generateErrorDesc(int errorCode) {
size_t serialize(const AVSubtitle& sub, ByteStorage* out) {
const auto len = size(sub);
CHECK_LE(len, out->tail());
TORCH_CHECK_LE(len, out->tail());
size_t pos = 0;
if (!Serializer::serializeItem(out->writableTail(), len, pos, sub)) {
return 0;
......
......@@ -7,6 +7,17 @@
namespace ffmpeg {
namespace {
// Setup the data pointers and linesizes based on the specified image
// parameters and the provided array. This sets up "planes" to point to a
// "buffer"
// NOTE: this is most likely the culprit behind #3534
//
// Args:
// fmt: desired output video format
// buffer: source constant image buffer (in a different format) that will
//     contain the final image after SWScale
// planes: destination data pointer to be filled
// lineSize: target destination linesize (always {0})
int preparePlanes(
const VideoFormat& fmt,
const uint8_t* buffer,
......@@ -14,6 +25,7 @@ int preparePlanes(
int* lineSize) {
int result;
// NOTE: the 1 at the end of av_image_fill_arrays is the value used for alignment
if ((result = av_image_fill_arrays(
planes,
lineSize,
......@@ -28,6 +40,18 @@ int preparePlanes(
return result;
}
// Scale (and crop) the image slice in srcSlice and put the resulting scaled
// slice to `planes` buffer, which is mapped to be `out` via preparePlanes as
// `sws_scale` cannot access buffers directly.
//
// Args:
// context: SWSContext allocated on line 119 (if crop, optional) or 163 (if
//     scale)
// srcSlice: frame data in YUV420P
// srcStride: the array containing the strides for each plane of the source
//     image (from AVFrame->linesize[0])
// out: destination buffer
// planes: indirect destination buffer (mapped to "out" via preparePlanes)
// lines: destination linesize; constant {0}
int transformImage(
SwsContext* context,
const uint8_t* const srcSlice[],
......@@ -41,13 +65,32 @@ int transformImage(
if ((result = preparePlanes(outFormat, out, planes, lines)) < 0) {
return result;
}
if (context) {
// NOTE: srcSliceY (the 4th sws_scale argument) is always 0, i.e. the slice
// starts at the top of the frame
if ((result = sws_scale(
context, srcSlice, srcStride, 0, inFormat.height, planes, lines)) <
0) {
LOG(ERROR) << "sws_scale failed, err: " << Util::generateErrorDesc(result);
LOG(ERROR) << "sws_scale failed, err: "
<< Util::generateErrorDesc(result);
return result;
}
} else if (
inFormat.width == outFormat.width &&
inFormat.height == outFormat.height &&
inFormat.format == outFormat.format) {
// Copy planes without using sws_scale if sws_getContext failed.
av_image_copy(
planes,
lines,
(const uint8_t**)srcSlice,
srcStride,
(AVPixelFormat)inFormat.format,
inFormat.width,
inFormat.height);
} else {
LOG(ERROR) << "Invalid scale context format " << inFormat.format;
return AVERROR(EINVAL);
}
return 0;
}
} // namespace
......@@ -135,6 +178,26 @@ bool VideoSampler::init(const SamplerParameters& params) {
<< params.out.video.minDimension << ", cropImage "
<< params.out.video.cropImage;
// set output format
params_ = params;
if (params.in.video.format == AV_PIX_FMT_YUV420P) {
/* When the video width and height are not multiples of 8,
 * and there is no size change in the conversion,
 * blurry artifacts appear on the right side of the image.
 * This problem was discovered in 2012 and
 * continues to exist in version 4.1.3 in 2019.
 * It can be avoided by adding the SWS_ACCURATE_RND flag;
 * details: https://trac.ffmpeg.org/ticket/1582
 */
if ((params.in.video.width & 0x7) || (params.in.video.height & 0x7)) {
VLOG(1) << "The width " << params.in.video.width << " and height "
<< params.in.video.height << " the image is not a multiple of 8, "
<< "the decoding speed may be reduced";
swsFlags_ |= SWS_ACCURATE_RND;
}
}
scaleContext_ = sws_getContext(
params.in.video.width,
params.in.video.height,
......@@ -146,13 +209,24 @@ bool VideoSampler::init(const SamplerParameters& params) {
nullptr,
nullptr,
nullptr);
// set output format
params_ = params;
// sws_getContext might fail if in/out format == AV_PIX_FMT_PAL8 (png format)
// Return true if input and output formats/width/height are identical
// Check scaleContext_ for nullptr in transformImage to copy planes directly
if (params.in.video.width == scaleFormat_.width &&
params.in.video.height == scaleFormat_.height &&
params.in.video.format == scaleFormat_.format) {
return true;
}
return scaleContext_ != nullptr;
}
// Main body of the sample function called from one of the overloads below
//
// Args:
// srcSlice: buffer prepared from the decoded AVFrame->data
// srcStride: linesize (usually obtained from AVFrame->linesize)
// out: return buffer (ByteStorage*)
int VideoSampler::sample(
const uint8_t* const srcSlice[],
int srcStride[],
......@@ -221,6 +295,7 @@ int VideoSampler::sample(
return outImageSize;
}
// Call from `video_stream.cpp::114` - occurs during file reads
int VideoSampler::sample(AVFrame* frame, ByteStorage* out) {
if (!frame) {
return 0; // no flush for videos
......@@ -229,6 +304,7 @@ int VideoSampler::sample(AVFrame* frame, ByteStorage* out) {
return sample(frame->data, frame->linesize, out);
}
// Call from `video_stream.cpp::114` - not sure when this occurs
int VideoSampler::sample(const ByteStorage* in, ByteStorage* out) {
if (!in) {
return 0; // no flush for videos
......
......@@ -6,11 +6,13 @@ namespace ffmpeg {
namespace {
bool operator==(const VideoFormat& x, const AVFrame& y) {
return x.width == y.width && x.height == y.height && x.format == y.format;
return x.width == static_cast<size_t>(y.width) &&
x.height == static_cast<size_t>(y.height) && x.format == y.format;
}
bool operator==(const VideoFormat& x, const AVCodecContext& y) {
return x.width == y.width && x.height == y.height && x.format == y.pix_fmt;
return x.width == static_cast<size_t>(y.width) &&
x.height == static_cast<size_t>(y.height) && x.format == y.pix_fmt;
}
VideoFormat& toVideoFormat(VideoFormat& x, const AVFrame& y) {
......@@ -80,6 +82,7 @@ int VideoStream::initFormat() {
: -1;
}
// copies frame bytes via sws_scale call in video_sampler.cpp
int VideoStream::copyFrameBytes(ByteStorage* out, bool flush) {
if (!sampler_) {
sampler_ = std::make_unique<VideoSampler>(SWS_AREA, loggingUuid_);
......@@ -110,7 +113,9 @@ int VideoStream::copyFrameBytes(ByteStorage* out, bool flush) {
<< ", minDimension: " << format_.format.video.minDimension
<< ", crop: " << format_.format.video.cropImage;
}
// calls a sampler that converts the frame from YUV422 to RGB24, and
// optionally crops and resizes the frame. Frame bytes are copied from
// frame_->data to out buffer
return sampler_->sample(flush ? nullptr : frame_, out);
}
......
#include "decode_gif.h"
#include <cstring>
#include "giflib/gif_lib.h"
namespace vision {
namespace image {
typedef struct reader_helper_t {
uint8_t const* encoded_data; // input tensor data pointer
size_t encoded_data_size; // size of input tensor in bytes
size_t num_bytes_read; // number of bytes read so far in the tensor
} reader_helper_t;
// This function is used by GIFLIB routines to read the encoded bytes.
// This reads `len` bytes and writes them into `buf`. The data is read from the
// input tensor passed to decode_gif() starting at the `num_bytes_read`
// position.
int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) {
// the UserData field was set in DGifOpen()
reader_helper_t* reader_helper =
static_cast<reader_helper_t*>(gifFile->UserData);
size_t num_bytes_to_read = std::min(
(size_t)len,
reader_helper->encoded_data_size - reader_helper->num_bytes_read);
std::memcpy(
buf,
reader_helper->encoded_data + reader_helper->num_bytes_read,
num_bytes_to_read);
reader_helper->num_bytes_read += num_bytes_to_read;
return num_bytes_to_read;
}
torch::Tensor decode_gif(const torch::Tensor& encoded_data) {
// LibGif docs: https://giflib.sourceforge.net/intro.html
// Refer over there for more details on the libgif API, API ref, and a
// detailed description of the GIF format.
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");
int error = D_GIF_SUCCEEDED;
// We're using DGifOpen. The other entrypoints of libgif are
// DGifOpenFileName and DGifOpenFileHandle but we don't want to use those,
// since we need to read the encoded bytes from a tensor of encoded bytes, not
// from a file (for consistency with existing jpeg and png decoders). Using
// DGifOpen is the only way to read from a custom source.
// For that we need to provide a reader function `read_from_tensor` that
// reads from the tensor, and we have to keep track of the number of bytes
// read so far: this is why we need the reader_helper struct.
// TODO: We are potentially doing an unnecessary copy of the encoded bytes:
// - 1 copy in from file to tensor (in read_file())
// - 1 copy from tensor to GIFLIB buffers (in read_from_tensor())
// Since we're vendoring GIFLIB we can potentially modify the calls to
// InternalRead() and just set the `buf` pointer to the tensor data directly.
// That might even save allocation of those buffers.
// If we do that, we'd have to make sure the buffers are never written to by
// GIFLIB, otherwise we'd be overriding the tensor data.
reader_helper_t reader_helper;
reader_helper.encoded_data = encoded_data.data_ptr<uint8_t>();
reader_helper.encoded_data_size = encoded_data.numel();
reader_helper.num_bytes_read = 0;
GifFileType* gifFile =
DGifOpen(static_cast<void*>(&reader_helper), read_from_tensor, &error);
TORCH_CHECK(
(gifFile != nullptr) && (error == D_GIF_SUCCEEDED),
"DGifOpenFileName() failed - ",
error);
if (DGifSlurp(gifFile) == GIF_ERROR) {
auto gifFileError = gifFile->Error;
DGifCloseFile(gifFile, &error);
TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError);
}
auto num_images = gifFile->ImageCount;
// This check should already be done within DGifSlurp(); just to be safe
TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!");
GifColorType bg = {0, 0, 0};
if (gifFile->SColorMap) {
bg = gifFile->SColorMap->Colors[gifFile->SBackGroundColor];
}
// The GIFLIB docs say that the canvas's height and width are potentially
// ignored by modern viewers, so to be on the safe side we set the output
// height to max(canvas_height, first_image_height). Same for width.
// https://giflib.sourceforge.net/whatsinagif/bits_and_bytes.html
auto out_h =
std::max(gifFile->SHeight, gifFile->SavedImages[0].ImageDesc.Height);
auto out_w =
std::max(gifFile->SWidth, gifFile->SavedImages[0].ImageDesc.Width);
// We output a channels-last tensor for consistency with other image decoders.
// Torchvision's resize tends to be faster on uint8 channels-last tensors.
auto options = torch::TensorOptions()
.dtype(torch::kU8)
.memory_format(torch::MemoryFormat::ChannelsLast);
auto out = torch::empty(
{int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options);
auto out_a = out.accessor<uint8_t, 4>();
for (int i = 0; i < num_images; i++) {
const SavedImage& img = gifFile->SavedImages[i];
GraphicsControlBlock gcb;
DGifSavedExtensionToGCB(gifFile, i, &gcb);
const GifImageDesc& desc = img.ImageDesc;
const ColorMapObject* cmap =
desc.ColorMap ? desc.ColorMap : gifFile->SColorMap;
TORCH_CHECK(
cmap != nullptr,
"Global and local color maps are missing. This should never happen!");
// When going from one image to another, there is a "disposal method" which
// specifies how to handle the transition. E.g. DISPOSE_DO_NOT means that
// the current image should essentially be drawn on top of the previous
// canvas. The pixels of that previous canvas will appear on the new one if
// either:
// - a pixel is transparent in the current image
// - the current image is smaller than the canvas, hence exposing its pixels
// The "background" disposal method means that the current canvas should be
// set to the background color.
// We only support these 2 modes and default to "background" when the
// disposal method is unspecified, or when it's set to "DISPOSE_PREVIOUS"
// which according to GIFLIB is not widely supported.
// (https://giflib.sourceforge.net/whatsinagif/animation_and_transparency.html).
if (i > 0 && gcb.DisposalMode == DISPOSE_DO_NOT) {
out[i] = out[i - 1];
} else {
// Background. If bg wasn't defined, it will be (0, 0, 0)
for (int h = 0; h < gifFile->SHeight; h++) {
for (int w = 0; w < gifFile->SWidth; w++) {
out_a[i][0][h][w] = bg.Red;
out_a[i][1][h][w] = bg.Green;
out_a[i][2][h][w] = bg.Blue;
}
}
}
for (int h = 0; h < desc.Height; h++) {
for (int w = 0; w < desc.Width; w++) {
auto c = img.RasterBits[h * desc.Width + w];
if (c == gcb.TransparentColor) {
continue;
}
GifColorType rgb = cmap->Colors[c];
out_a[i][0][h + desc.Top][w + desc.Left] = rgb.Red;
out_a[i][1][h + desc.Top][w + desc.Left] = rgb.Green;
out_a[i][2][h + desc.Top][w + desc.Left] = rgb.Blue;
}
}
}
out = out.squeeze(0); // remove batch dim if there's only one image
DGifCloseFile(gifFile, &error);
TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error);
return out;
}
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
namespace vision {
namespace image {
// encoded_data tensor must be 1D uint8 and contiguous
C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data);
} // namespace image
} // namespace vision
#include "decode_image.h"
#include "decode_gif.h"
#include "decode_jpeg.h"
#include "decode_png.h"
namespace vision {
namespace image {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -18,15 +24,24 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
const uint8_t gif_signature_1[6] = {
0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a"
const uint8_t gif_signature_2[6] = {
0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a"
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode);
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(data, mode);
return decode_png(
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
} else if (
memcmp(gif_signature_1, datap, 6) == 0 ||
memcmp(gif_signature_2, datap, 6) == 0) {
return decode_gif(data);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg and png ",
"Unsupported image file. Only jpeg, png and gif ",
"are currently supported.");
}
}
......
......@@ -8,7 +8,8 @@ namespace image {
C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);
} // namespace image
} // namespace vision
#include "decode_jpeg.h"
#include "common_jpeg.h"
#include "exif.h"
namespace vision {
namespace image {
#if !JPEG_FOUND
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
TORCH_CHECK(
false, "decode_jpeg: torchvision not compiled with libjpeg support");
}
#else
using namespace detail;
using namespace exif_private;
namespace {
......@@ -65,11 +70,70 @@ static void torch_jpeg_set_source_mgr(
src->len = len;
src->pub.bytes_in_buffer = len;
src->pub.next_input_byte = src->data;
jpeg_save_markers(cinfo, APP1, 0xffff);
}
inline unsigned char clamped_cmyk_rgb_convert(
unsigned char k,
unsigned char cmy) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
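// The two shifts below approximate an integer division by 255:
// ((v >> 8) + v) >> 8 ~= v / 255, so the result is roughly
// k - k * cmy / 255 = k * (255 - cmy) / 255, clamped to [0, 255].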
int v = k * cmy + 128;
v = ((v >> 8) + v) >> 8;
return std::clamp(k - v, 0, 255);
}
void convert_line_cmyk_to_rgb(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* rgb_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];
rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c);
rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m);
rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y);
}
}
inline unsigned char rgb_to_gray(int r, int g, int b) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
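// The weights are the ITU-R BT.601 luma coefficients scaled by 2^16
// (19595/65536 ~= 0.299, 38470/65536 ~= 0.587, 7471/65536 ~= 0.114);
// adding 0x8000 rounds the fixed-point result to the nearest integer.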
return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
}
void convert_line_cmyk_to_gray(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* gray_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];
int r = clamped_cmyk_rgb_convert(k, 255 - c);
int g = clamped_cmyk_rgb_convert(k, 255 - m);
int b = clamped_cmyk_rgb_convert(k, 255 - y);
gray_line[i] = rgb_to_gray(r, g, b);
}
}
} // namespace
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -100,20 +164,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_read_header(&cinfo, TRUE);
int channels = cinfo.num_components;
bool cmyk_to_rgb_or_gray = false;
if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
channels = 1;
break;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
channels = 3;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
......@@ -128,6 +201,11 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_calc_output_dimensions(&cinfo);
}
int exif_orientation = -1;
if (apply_exif_orientation) {
exif_orientation = fetch_jpeg_exif_orientation(&cinfo);
}
jpeg_start_decompress(&cinfo);
int height = cinfo.output_height;
......@@ -137,21 +215,57 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
}
while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
if (channels == 3) {
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
} else if (channels == 1) {
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
}
} else {
jpeg_read_scanlines(&cinfo, &ptr, 1);
}
ptr += stride;
}
jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
auto output = tensor.permute({2, 0, 1});
if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
}
return output;
}
#endif // #if !JPEG_FOUND
int64_t _jpeg_version() {
#if JPEG_FOUND
return JPEG_LIB_VERSION;
#else
return -1;
#endif
}
bool _is_compiled_against_turbo() {
#ifdef LIBJPEG_TURBO_VERSION
return true;
#else
return false;
#endif
}
} // namespace image
} // namespace vision
......@@ -8,7 +8,11 @@ namespace image {
C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);
C10_EXPORT int64_t _jpeg_version();
C10_EXPORT bool _is_compiled_against_turbo();
} // namespace image
} // namespace vision
#include "decode_png.h"
#include "common_png.h"
#include "exif.h"
namespace vision {
namespace image {
using namespace exif_private;
#if !PNG_FOUND
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
}
#else
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
bool is_little_endian() {
uint32_t x = 1;
return *(uint8_t*)&x;
}
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -29,25 +46,35 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
}
auto datap = data.accessor<unsigned char, 1>().data();
auto accessor = data.accessor<unsigned char, 1>();
auto datap = accessor.data();
auto datap_len = accessor.size(0);
if (setjmp(png_jmpbuf(png_ptr)) != 0) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Internal error.");
}
TORCH_CHECK(datap_len >= 8, "Content is too small for png!")
auto is_png = !png_sig_cmp(datap, 0, 8);
TORCH_CHECK(is_png, "Content is not png!")
struct Reader {
png_const_bytep ptr;
png_size_t count;
} reader;
reader.ptr = png_const_bytep(datap) + 8;
reader.count = datap_len - 8;
auto read_callback =
[](png_structp png_ptr, png_bytep output, png_size_t bytes) {
auto read_callback = [](png_structp png_ptr,
png_bytep output,
png_size_t bytes) {
auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
TORCH_CHECK(
reader->count >= bytes,
"Out of bound read in decode_png. Probably, the input image is corrupted");
std::copy(reader->ptr, reader->ptr + bytes, output);
reader->ptr += bytes;
reader->count -= bytes;
};
png_set_sig_bytes(png_ptr, 8);
png_set_read_fn(png_ptr, &reader, read_callback);
......@@ -55,6 +82,7 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
png_uint_32 width, height;
int bit_depth, color_type;
int interlace_type;
auto retval = png_get_IHDR(
png_ptr,
info_ptr,
......@@ -62,7 +90,7 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
&height,
&bit_depth,
&color_type,
nullptr,
&interlace_type,
nullptr,
nullptr);
......@@ -71,8 +99,26 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, err_msg)
}
int channels = png_get_channels(png_ptr, info_ptr);
if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
png_set_expand_gray_1_2_4_to_8(png_ptr);
int number_of_passes;
if (interlace_type == PNG_INTERLACE_ADAM7) {
number_of_passes = png_set_interlace_handling(png_ptr);
} else {
number_of_passes = 1;
}
if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
......@@ -152,16 +198,60 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
png_read_update_info(png_ptr, info_ptr);
}
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
auto num_pixels_per_row = width * channels;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);
if (bit_depth <= 8) {
auto t_ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<uint8_t, 3>().data();
}
} else {
// We're reading a 16-bit png, but pytorch doesn't support uint16.
// So we read each row into a 16-bit tmp_buffer which we then cast into
// an int32 tensor instead.
if (is_little_endian()) {
png_set_swap(png_ptr);
}
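// libpng hands out 16-bit samples in big-endian order; png_set_swap above
// makes sure they are byte-swapped on little-endian hosts before we widen
// them to int32 below.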
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
// We create a tensor instead of malloc-ing for automatic memory management
auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
// Now we copy the uint16 values into the int32 tensor.
for (size_t j = 0; j < num_pixels_per_row; ++j) {
t_ptr[j] = (int32_t)tmp_buffer[j];
}
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<int32_t, 3>().data();
}
}
int exif_orientation = -1;
if (apply_exif_orientation) {
exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr);
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
auto output = tensor.permute({2, 0, 1});
if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
}
return output;
}
#endif
......
......@@ -8,7 +8,9 @@ namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false,
bool apply_exif_orientation = false);
} // namespace image
} // namespace vision