Commit 0fc002df authored by huchen

init the dlexamples new

parent 0e04b692
#pragma once
#include <list>
#include "decoder.h"
namespace ffmpeg {
/**
* Class uses the FFMPEG library to decode media streams.
* Media bytes can be provided explicitly through a read callback
* or fetched internally by the FFMPEG library.
*/
class SyncDecoder : public Decoder {
public:
// Allocation of memory must be done with a proper alignment.
class AVByteStorage : public ByteStorage {
public:
explicit AVByteStorage(size_t n);
~AVByteStorage() override;
void ensure(size_t n) override;
uint8_t* writableTail() override;
void append(size_t n) override;
void trim(size_t n) override;
const uint8_t* data() const override;
size_t length() const override;
size_t tail() const override;
void clear() override;
private:
size_t offset_{0};
size_t length_{0};
size_t capacity_{0};
uint8_t* buffer_{nullptr};
};
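// Contract sketch for the ByteStorage interface, inferred from how it is
// used in util.cpp and video_sampler.cpp below (not an authoritative spec):
//
//   storage.ensure(n);                       // guarantee n writable bytes
//   memcpy(storage.writableTail(), src, n);  // fill the uncommitted tail
//   storage.append(n);                       // commit n bytes to length()
//   storage.data(); storage.length();        // read the committed region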
public:
int decode(DecoderOutputMessage* out, uint64_t timeoutMs) override;
private:
void push(DecoderOutputMessage&& buffer) override;
void onInit() override;
std::unique_ptr<ByteStorage> createByteStorage(size_t n) override;
private:
std::list<DecoderOutputMessage> queue_;
bool eof_{false};
};
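// A minimal usage sketch (hypothetical path; mirrors the init()/decode()/
// shutdown() contract exercised by the tests below):
//
//   SyncDecoder decoder;
//   DecoderParameters params;
//   params.timeoutMs = 10000;
//   params.uri = "/path/to/video.mp4"; // or pass a read callback to init()
//   std::vector<DecoderMetadata> metadata;
//   if (decoder.init(params, nullptr, &metadata)) {
//     DecoderOutputMessage out;
//     while (decoder.decode(&out, params.timeoutMs) == 0) {
//       // consume out.header / out.payload
//     }
//     decoder.shutdown();
//   }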
} // namespace ffmpeg
#include <c10/util/Logging.h>
#include <dirent.h>
#include <gtest/gtest.h>
#include "memory_buffer.h"
#include "sync_decoder.h"
#include "util.h"
using namespace ffmpeg;
namespace {
struct VideoFileStats {
std::string name;
size_t durationPts{0};
int num{0};
int den{0};
int fps{0};
};
void gotAllTestFiles(
const std::string& folder,
std::vector<VideoFileStats>* stats) {
DIR* d = opendir(folder.c_str());
CHECK(d);
struct dirent* dir;
while ((dir = readdir(d))) {
if (dir->d_type != DT_DIR && 0 != strcmp(dir->d_name, "README")) {
VideoFileStats item;
item.name = folder + '/' + dir->d_name;
LOG(INFO) << "Found video file: " << item.name;
stats->push_back(std::move(item));
}
}
closedir(d);
}
void gotFilesStats(std::vector<VideoFileStats>& stats) {
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.formats = {MediaFormat(0)};
params.headerOnly = true;
params.preventStaleness = false;
size_t avgProbeUs = 0;
const size_t rounds = 100;
for (auto& item : stats) {
LOG(INFO) << "Decoding video file in memory: " << item.name;
FILE* f = fopen(item.name.c_str(), "rb");
CHECK(f != nullptr);
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
fclose(f);
for (size_t i = 0; i < rounds; ++i) {
SyncDecoder decoder;
std::vector<DecoderMetadata> metadata;
const auto now = std::chrono::steady_clock::now();
CHECK(decoder.init(
params,
MemoryBuffer::getCallback(buffer.data(), buffer.size()),
&metadata));
const auto then = std::chrono::steady_clock::now();
decoder.shutdown();
avgProbeUs +=
std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count();
CHECK_EQ(metadata.size(), 1);
item.num = metadata[0].num;
item.den = metadata[0].den;
item.fps = metadata[0].fps;
item.durationPts =
av_rescale_q(metadata[0].duration, AV_TIME_BASE_Q, {1, item.fps});
}
}
LOG(INFO) << "Probing (us) " << avgProvUs / stats.size() / rounds;
}
size_t measurePerformanceUs(
const std::vector<VideoFileStats>& stats,
size_t rounds,
size_t num,
size_t stride) {
size_t avgClipDecodingUs = 0;
std::srand(time(nullptr));
for (const auto& item : stats) {
FILE* f = fopen(item.name.c_str(), "rb");
CHECK(f != nullptr);
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
fclose(f);
for (size_t i = 0; i < rounds; ++i) {
// randomly select a clip
size_t rOffset = std::rand();
size_t fOffset = rOffset % item.durationPts;
size_t clipFrames = num + (num - 1) * stride;
if (fOffset + clipFrames > item.durationPts) {
fOffset = item.durationPts - clipFrames;
}
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.preventStaleness = false;
for (size_t n = 0; n < num; ++n) {
std::list<DecoderOutputMessage> msgs;
params.startOffset =
av_rescale_q(fOffset, {1, item.fps}, AV_TIME_BASE_Q);
params.endOffset = params.startOffset + 100;
auto now = std::chrono::steady_clock::now();
SyncDecoder decoder;
CHECK(decoder.init(
params,
MemoryBuffer::getCallback(buffer.data(), buffer.size()),
nullptr));
DecoderOutputMessage out;
while (0 == decoder.decode(&out, params.timeoutMs)) {
msgs.push_back(std::move(out));
}
decoder.shutdown();
const auto then = std::chrono::steady_clock::now();
fOffset += 1 + stride;
avgClipDecodingUs +=
std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count();
}
}
}
return avgClipDecodingUs / rounds / num / stats.size();
}
void runDecoder(SyncDecoder& decoder) {
DecoderOutputMessage out;
size_t audioFrames = 0, videoFrames = 0, totalBytes = 0;
while (0 == decoder.decode(&out, 10000)) {
if (out.header.format.type == TYPE_AUDIO) {
++audioFrames;
} else if (out.header.format.type == TYPE_VIDEO) {
++videoFrames;
} else if (out.header.format.type == TYPE_SUBTITLE && out.payload) {
// deserialize
LOG(INFO) << "Deserializing subtitle";
AVSubtitle sub;
memset(&sub, 0, sizeof(sub));
EXPECT_TRUE(Util::deserialize(*out.payload, &sub));
LOG(INFO) << "Found subtitles"
<< ", num rects: " << sub.num_rects;
for (unsigned i = 0; i < sub.num_rects; ++i) {
std::string text = "picture";
if (sub.rects[i]->type == SUBTITLE_TEXT) {
text = sub.rects[i]->text;
} else if (sub.rects[i]->type == SUBTITLE_ASS) {
text = sub.rects[i]->ass;
}
LOG(INFO) << "Rect num: " << i << ", type:" << sub.rects[i]->type
<< ", text: " << text;
}
avsubtitle_free(&sub);
}
if (out.payload) {
totalBytes += out.payload->length();
}
}
LOG(INFO) << "Decoded audio frames: " << audioFrames
<< ", video frames: " << videoFrames
<< ", total bytes: " << totalBytes;
}
} // namespace
TEST(SyncDecoder, TestSyncDecoderPerformance) {
// Measure the average decoding time per clip:
// 1. List the videos in the testing directory.
// 2. For each video, get the number of frames and timestamps.
// 3. Randomly select a frame offset.
// 4. Adjust the offset for the number of frames and strides
//    if it is beyond the upper boundary.
// 5. Repeat multiple times, measuring and accumulating the decoding
//    time per clip.
/*
1) 4 x 2
2) 8 x 8
3) 16 x 8
4) 32 x 4
*/
const std::string kFolder = "pytorch/vision/test/assets/videos";
std::vector<VideoFileStats> stats;
gotAllTestFiles(kFolder, &stats);
gotFilesStats(stats);
const size_t kRounds = 10;
auto new4x2 = measurePerformanceUs(stats, kRounds, 4, 2);
auto new8x8 = measurePerformanceUs(stats, kRounds, 8, 8);
auto new16x8 = measurePerformanceUs(stats, kRounds, 16, 8);
auto new32x4 = measurePerformanceUs(stats, kRounds, 32, 4);
LOG(INFO) << "Clip decoding (us)"
<< ", new(4x2): " << new4x2 << ", new(8x8): " << new8x8
<< ", new(16x8): " << new16x8 << ", new(32x4): " << new32x4;
}
TEST(SyncDecoder, Test) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
params.uri = "pytorch/vision/test/assets/videos/R6llTwEh07w.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
}
TEST(SyncDecoder, TestSubtitles) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
params.uri = "vue/synergy/data/robotsub.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
}
TEST(SyncDecoder, TestHeadersOnly) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.headerOnly = true;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
params.uri = "pytorch/vision/test/assets/videos/R6llTwEh07w.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
params.uri = "pytorch/vision/test/assets/videos/SOX5yA1l24A.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
params.uri = "pytorch/vision/test/assets/videos/WUzgd7C1pWA.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
}
TEST(SyncDecoder, TestHeadersOnlyDownSampling) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.headerOnly = true;
MediaFormat format;
format.type = TYPE_AUDIO;
format.format.audio.samples = 8000;
params.formats.insert(format);
format.type = TYPE_VIDEO;
format.format.video.width = 224;
format.format.video.height = 224;
params.formats.insert(format);
params.uri = "pytorch/vision/test/assets/videos/R6llTwEh07w.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
params.uri = "pytorch/vision/test/assets/videos/SOX5yA1l24A.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
params.uri = "pytorch/vision/test/assets/videos/WUzgd7C1pWA.mp4";
CHECK(decoder.init(params, nullptr, nullptr));
runDecoder(decoder);
decoder.shutdown();
}
TEST(SyncDecoder, TestInitOnlyNoShutdown) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.seekAccuracy = 100000;
params.headerOnly = false;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
params.uri = "pytorch/vision/test/assets/videos/R6llTwEh07w.mp4";
std::vector<DecoderMetadata> metadata;
CHECK(decoder.init(params, nullptr, &metadata));
}
TEST(SyncDecoder, TestMemoryBuffer) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.endOffset = 9000000;
params.seekAccuracy = 10000;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
FILE* f = fopen(
"pytorch/vision/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi",
"rb");
CHECK(f != nullptr);
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
fclose(f);
CHECK(decoder.init(
params,
MemoryBuffer::getCallback(buffer.data(), buffer.size()),
nullptr));
LOG(INFO) << "Decoding from memory bytes: " << buffer.size();
runDecoder(decoder);
decoder.shutdown();
}
TEST(SyncDecoder, TestMemoryBufferNoSeekableWithFullRead) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.endOffset = 9000000;
params.seekAccuracy = 10000;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
FILE* f = fopen("pytorch/vision/test/assets/videos/R6llTwEh07w.mp4", "rb");
CHECK(f != nullptr);
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
fclose(f);
params.maxSeekableBytes = buffer.size() + 1;
MemoryBuffer object(buffer.data(), buffer.size());
CHECK(decoder.init(
params,
[object](uint8_t* out, int size, int whence, uint64_t timeoutMs) mutable
-> int {
if (out) { // see defs.h file
// read mode
return object.read(out, size);
}
// seek mode
if (!timeoutMs) {
// seek capability query: yes/no
return -1;
}
return object.seek(size, whence);
},
nullptr));
runDecoder(decoder);
decoder.shutdown();
}
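// The callback above plays three roles; a hedged summary of the protocol as
// inferred from this lambda and the defs.h hints it references:
//
//   out != nullptr                  -> read mode: copy up to `size` bytes
//                                      into `out`, return the bytes read
//   out == nullptr, timeoutMs == 0  -> capability probe: return >= 0 if
//                                      seekable, < 0 otherwise
//   out == nullptr, timeoutMs != 0  -> seek mode: seek to `size` using the
//                                      fseek-style `whence`, return position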
TEST(SyncDecoder, TestMemoryBufferNoSeekableWithPartialRead) {
SyncDecoder decoder;
DecoderParameters params;
params.timeoutMs = 10000;
params.startOffset = 1000000;
params.endOffset = 9000000;
params.seekAccuracy = 10000;
params.formats = {MediaFormat(), MediaFormat(0), MediaFormat('0')};
FILE* f = fopen("pytorch/vision/test/assets/videos/R6llTwEh07w.mp4", "rb");
CHECK(f != nullptr);
fseek(f, 0, SEEK_END);
std::vector<uint8_t> buffer(ftell(f));
rewind(f);
CHECK_EQ(buffer.size(), fread(buffer.data(), 1, buffer.size(), f));
fclose(f);
params.maxSeekableBytes = buffer.size() / 2;
MemoryBuffer object(buffer.data(), buffer.size());
CHECK(!decoder.init(
params,
[object](uint8_t* out, int size, int whence, uint64_t timeoutMs) mutable
-> int {
if (out) { // see defs.h file
// read mode
return object.read(out, size);
}
// seek mode
if (!timeoutMs) {
// seek capability query: yes/no
return -1;
}
return object.seek(size, whence);
},
nullptr));
}
#include "time_keeper.h"
#include "defs.h"
namespace ffmpeg {
namespace {
const long kMaxTimeBaseDifference = 10;
}
long TimeKeeper::adjust(long& decoderTimestamp) {
const long now = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
if (startTime_ == 0) {
startTime_ = now;
}
if (streamTimestamp_ == 0) {
streamTimestamp_ = decoderTimestamp;
}
const auto runOut = startTime_ + decoderTimestamp - streamTimestamp_;
if (std::labs((now - runOut) / AV_TIME_BASE) > kMaxTimeBaseDifference) {
streamTimestamp_ = startTime_ - now + decoderTimestamp;
}
const auto sleepAdvised = runOut - now;
decoderTimestamp += startTime_ - streamTimestamp_;
return sleepAdvised > 0 ? sleepAdvised : 0;
}
} // namespace ffmpeg
#pragma once
#include <stdlib.h>
#include <chrono>
namespace ffmpeg {
/**
* Class keeps track of the decoded timestamps (us) for media streams.
*/
class TimeKeeper {
public:
TimeKeeper() = default;
// adjusts the provided @decoderTimestamp to the corrected value
// returns the advised sleep time (us) before processing the next frame
long adjust(long& decoderTimestamp);
private:
long startTime_{0};
long streamTimestamp_{0};
};
} // namespace ffmpeg
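// A worked example of adjust(), assuming microsecond timestamps: on the
// first call at wall time now = T with decoderTimestamp = 5'000'000,
// startTime_ = T, streamTimestamp_ = 5'000'000 and the advised sleep is 0.
// If the next frame carries decoderTimestamp = 5'040'000 and arrives
// immediately (now still ~T), then runOut = T + 5'040'000 - 5'000'000 =
// T + 40'000, so adjust() advises sleeping ~40'000 us to keep wall-clock
// pacing, and the timestamp itself is rebased onto the wall clock by
// startTime_ - streamTimestamp_.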
#include "util.h"
#include <c10/util/Logging.h>
namespace ffmpeg {
namespace Serializer {
// fixed size types
template <typename T>
inline size_t getSize(const T& x) {
return sizeof(x);
}
template <typename T>
inline bool serializeItem(
uint8_t* dest,
size_t len,
size_t& pos,
const T& src) {
VLOG(6) << "Generic serializeItem";
const auto required = sizeof(src);
if (len < pos + required) {
return false;
}
memcpy(dest + pos, &src, required);
pos += required;
return true;
}
template <typename T>
inline bool deserializeItem(
const uint8_t* src,
size_t len,
size_t& pos,
T& dest) {
const auto required = sizeof(dest);
if (len < pos + required) {
return false;
}
memcpy(&dest, src + pos, required);
pos += required;
return true;
}
// AVSubtitleRect specialization
inline size_t getSize(const AVSubtitleRect& x) {
auto rectBytes = [](const AVSubtitleRect& y) -> size_t {
size_t s = 0;
switch (y.type) {
case SUBTITLE_BITMAP:
for (int i = 0; i < y.nb_colors; ++i) {
s += sizeof(y.pict.linesize[i]);
s += y.pict.linesize[i];
}
break;
case SUBTITLE_TEXT:
s += sizeof(size_t);
s += strlen(y.text);
break;
case SUBTITLE_ASS:
s += sizeof(size_t);
s += strlen(y.ass);
break;
default:
break;
}
return s;
};
return getSize(x.x) + getSize(x.y) + getSize(x.w) + getSize(x.h) +
getSize(x.nb_colors) + getSize(x.type) + getSize(x.flags) + rectBytes(x);
}
// AVSubtitle specialization
inline size_t getSize(const AVSubtitle& x) {
auto rectBytes = [](const AVSubtitle& y) -> size_t {
size_t s = getSize(y.num_rects);
for (unsigned i = 0; i < y.num_rects; ++i) {
s += getSize(*y.rects[i]);
}
return s;
};
return getSize(x.format) + getSize(x.start_display_time) +
getSize(x.end_display_time) + getSize(x.pts) + rectBytes(x);
}
inline bool serializeItem(
uint8_t* dest,
size_t len,
size_t& pos,
const AVSubtitleRect& src) {
auto rectSerialize =
[](uint8_t* d, size_t l, size_t& p, const AVSubtitleRect& x) -> bool {
switch (x.type) {
case SUBTITLE_BITMAP:
for (int i = 0; i < x.nb_colors; ++i) {
if (!serializeItem(d, l, p, x.pict.linesize[i])) {
return false;
}
if (p + x.pict.linesize[i] > l) {
return false;
}
memcpy(d + p, x.pict.data[i], x.pict.linesize[i]);
p += x.pict.linesize[i];
}
return true;
case SUBTITLE_TEXT: {
const size_t s = strlen(x.text);
if (!serializeItem(d, l, p, s)) {
return false;
}
if (p + s > l) {
return false;
}
memcpy(d + p, x.text, s);
p += s;
return true;
}
case SUBTITLE_ASS: {
const size_t s = strlen(x.ass);
if (!serializeItem(d, l, p, s)) {
return false;
}
if (p + s > l) {
return false;
}
memcpy(d + p, x.ass, s);
p += s;
return true;
}
default:
return true;
}
};
return serializeItem(dest, len, pos, src.x) &&
serializeItem(dest, len, pos, src.y) &&
serializeItem(dest, len, pos, src.w) &&
serializeItem(dest, len, pos, src.h) &&
serializeItem(dest, len, pos, src.nb_colors) &&
serializeItem(dest, len, pos, src.type) &&
serializeItem(dest, len, pos, src.flags) &&
rectSerialize(dest, len, pos, src);
}
inline bool serializeItem(
uint8_t* dest,
size_t len,
size_t& pos,
const AVSubtitle& src) {
auto rectSerialize =
[](uint8_t* d, size_t l, size_t& p, const AVSubtitle& x) -> bool {
bool res = serializeItem(d, l, p, x.num_rects);
for (unsigned i = 0; res && i < x.num_rects; ++i) {
res = serializeItem(d, l, p, *(x.rects[i]));
}
return res;
};
VLOG(6) << "AVSubtitle serializeItem";
return serializeItem(dest, len, pos, src.format) &&
serializeItem(dest, len, pos, src.start_display_time) &&
serializeItem(dest, len, pos, src.end_display_time) &&
serializeItem(dest, len, pos, src.pts) &&
rectSerialize(dest, len, pos, src);
}
inline bool deserializeItem(
const uint8_t* src,
size_t len,
size_t& pos,
AVSubtitleRect& dest) {
auto rectDeserialize =
[](const uint8_t* y, size_t l, size_t& p, AVSubtitleRect& x) -> bool {
switch (x.type) {
case SUBTITLE_BITMAP:
for (int i = 0; i < x.nb_colors; ++i) {
if (!deserializeItem(y, l, p, x.pict.linesize[i])) {
return false;
}
if (p + x.pict.linesize[i] > l) {
return false;
}
x.pict.data[i] = (uint8_t*)av_malloc(x.pict.linesize[i]);
memcpy(x.pict.data[i], y + p, x.pict.linesize[i]);
p += x.pict.linesize[i];
}
return true;
case SUBTITLE_TEXT: {
size_t s = 0;
if (!deserializeItem(y, l, p, s)) {
return false;
}
if (p + s > l) {
return false;
}
x.text = (char*)av_malloc(s + 1);
memcpy(x.text, y + p, s);
x.text[s] = 0;
p += s;
return true;
}
case SUBTITLE_ASS: {
size_t s = 0;
if (!deserializeItem(y, l, p, s)) {
return false;
}
if (p + s > l) {
return false;
}
x.ass = (char*)av_malloc(s + 1);
memcpy(x.ass, y + p, s);
x.ass[s] = 0;
p += s;
return true;
}
default:
return true;
}
};
return deserializeItem(src, len, pos, dest.x) &&
deserializeItem(src, len, pos, dest.y) &&
deserializeItem(src, len, pos, dest.w) &&
deserializeItem(src, len, pos, dest.h) &&
deserializeItem(src, len, pos, dest.nb_colors) &&
deserializeItem(src, len, pos, dest.type) &&
deserializeItem(src, len, pos, dest.flags) &&
rectDeserialize(src, len, pos, dest);
}
inline bool deserializeItem(
const uint8_t* src,
size_t len,
size_t& pos,
AVSubtitle& dest) {
auto rectDeserialize =
[](const uint8_t* y, size_t l, size_t& p, AVSubtitle& x) -> bool {
bool res = deserializeItem(y, l, p, x.num_rects);
if (res && x.num_rects) {
x.rects =
(AVSubtitleRect**)av_malloc(x.num_rects * sizeof(AVSubtitleRect*));
}
for (unsigned i = 0; res && i < x.num_rects; ++i) {
x.rects[i] = (AVSubtitleRect*)av_malloc(sizeof(AVSubtitleRect));
memset(x.rects[i], 0, sizeof(AVSubtitleRect));
res = deserializeItem(y, l, p, *x.rects[i]);
}
return res;
};
return deserializeItem(src, len, pos, dest.format) &&
deserializeItem(src, len, pos, dest.start_display_time) &&
deserializeItem(src, len, pos, dest.end_display_time) &&
deserializeItem(src, len, pos, dest.pts) &&
rectDeserialize(src, len, pos, dest);
}
} // namespace Serializer
namespace Util {
std::string generateErrorDesc(int errorCode) {
std::array<char, 1024> buffer;
if (av_strerror(errorCode, buffer.data(), buffer.size()) < 0) {
return std::string("Unknown error code: ") + std::to_string(errorCode);
}
buffer.back() = 0;
return std::string(buffer.data());
}
size_t serialize(const AVSubtitle& sub, ByteStorage* out) {
const auto len = size(sub);
CHECK_LE(len, out->tail());
size_t pos = 0;
if (!Serializer::serializeItem(out->writableTail(), len, pos, sub)) {
return 0;
}
out->append(len);
return len;
}
bool deserialize(const ByteStorage& buf, AVSubtitle* sub) {
size_t pos = 0;
return Serializer::deserializeItem(buf.data(), buf.length(), pos, *sub);
}
size_t size(const AVSubtitle& sub) {
return Serializer::getSize(sub);
}
bool validateVideoFormat(const VideoFormat& f) {
// clang-format off
/*
Valid parameters values for decoder
____________________________________________________________________________________
| W | H | minDimension | maxDimension | cropImage | algorithm |
|__________________________________________________________________________________|
| 0 | 0 | 0 | 0 | N/A | original |
|__________________________________________________________________________________|
| >0 | 0 | N/A | N/A | N/A | scale keeping W |
|__________________________________________________________________________________|
| 0 | >0 | N/A | N/A | N/A | scale keeping H |
|__________________________________________________________________________________|
| >0 | >0 | N/A | N/A | 0 | stretch/scale |
|__________________________________________________________________________________|
| >0 | >0 | N/A | N/A | >0 | scale/crop |
|__________________________________________________________________________________|
| 0 | 0 | >0 | 0 | N/A |scale to min dimension |
|__________________________________________________________________________________|
| 0 | 0 | 0 | >0 | N/A |scale to max dimension |
|__________________________________________________________________________________|
| 0 | 0 | >0 | >0 | N/A |stretch to min/max dimension|
|_____|_____|______________|______________|___________|____________________________|
*/
// clang-format on
return (f.width == 0 && // #1, #6, #7 and #8
f.height == 0 && f.cropImage == 0) ||
(f.width != 0 && // #4 and #5
f.height != 0 && f.minDimension == 0 && f.maxDimension == 0) ||
(((f.width != 0 && // #2
f.height == 0) ||
(f.width == 0 && // #3
f.height != 0)) &&
f.minDimension == 0 && f.maxDimension == 0 && f.cropImage == 0);
}
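// A few concrete checks against the table above (numbers are illustrative):
// {width=0, height=0, minDimension=0, maxDimension=0} is valid (#1, keep
// original); {width=86, height=0} is valid (#2, scale keeping W);
// {width=86, height=64, cropImage=1} is valid (#5, scale/crop); but
// {width=86, height=0, minDimension=100} is rejected, since rows #2/#3
// require minDimension, maxDimension and cropImage to all be zero.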
void setFormatDimensions(
size_t& destW,
size_t& destH,
size_t userW,
size_t userH,
size_t srcW,
size_t srcH,
size_t minDimension,
size_t maxDimension,
size_t cropImage) {
// Rounding rule: convert int -> double, round up if the fraction is >= 0.5
// and round down otherwise, i.e. int result = int(double(value) + 0.5);
// round() below rounds double to int according to this rule.
// #1, #6, #7 and #8
if (userW == 0 && userH == 0) {
if (minDimension > 0 && maxDimension == 0) { // #6
if (srcW > srcH) {
// landscape
destH = minDimension;
destW = round(double(srcW * minDimension) / srcH);
} else {
// portrait
destW = minDimension;
destH = round(double(srcH * minDimension) / srcW);
}
} else if (minDimension == 0 && maxDimension > 0) { // #7
if (srcW > srcH) {
// landscape
destW = maxDimension;
destH = round(double(srcH * maxDimension) / srcW);
} else {
// portrait
destH = maxDimension;
destW = round(double(srcW * maxDimension) / srcH);
}
} else if (minDimension > 0 && maxDimension > 0) { // #8
if (srcW > srcH) {
// landscape
destW = maxDimension;
destH = minDimension;
} else {
// portrait
destW = minDimension;
destH = maxDimension;
}
} else { // #1
destW = srcW;
destH = srcH;
}
} else if (userW != 0 && userH == 0) { // #2
destW = userW;
destH = round(double(srcH * userW) / srcW);
} else if (userW == 0 && userH != 0) { // #3
destW = round(double(srcW * userH) / srcH);
destH = userH;
} else { // userW != 0 && userH != 0
if (cropImage == 0) { // #4
destW = userW;
destH = userH;
} else { // #5
double userSlope = double(userH) / userW;
double srcSlope = double(srcH) / srcW;
if (srcSlope < userSlope) {
destW = round(double(srcW * userH) / srcH);
destH = userH;
} else {
destW = userW;
destH = round(double(srcH * userW) / srcW);
}
}
}
// prevent zeros
destW = std::max(destW, size_t(1UL));
destH = std::max(destH, size_t(1UL));
}
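// Worked example for case #6 (matches the util_test.cpp cases below): with
// srcW = 172, srcH = 128 and minDimension = 256, the source is landscape,
// so destH = 256 and destW = round(172 * 256 / 128.0) = 344; the smaller
// dimension hits minDimension while the aspect ratio is preserved.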
} // namespace Util
} // namespace ffmpeg
#pragma once
#include "defs.h"
namespace ffmpeg {
/**
* FFMPEG library utility functions.
*/
namespace Util {
std::string generateErrorDesc(int errorCode);
size_t serialize(const AVSubtitle& sub, ByteStorage* out);
bool deserialize(const ByteStorage& buf, AVSubtitle* sub);
size_t size(const AVSubtitle& sub);
void setFormatDimensions(
size_t& destW,
size_t& destH,
size_t userW,
size_t userH,
size_t srcW,
size_t srcH,
size_t minDimension,
size_t maxDimension,
size_t cropImage);
bool validateVideoFormat(const VideoFormat& format);
} // namespace Util
} // namespace ffmpeg
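// Hedged round-trip sketch for the subtitle helpers above (assumes a
// ByteStorage implementation such as SyncDecoder::AVByteStorage):
//
//   AVSubtitle sub = ...;          // e.g. produced by avcodec
//   SyncDecoder::AVByteStorage buf(0);
//   buf.ensure(Util::size(sub));   // reserve exactly the serialized size
//   Util::serialize(sub, &buf);
//   AVSubtitle copy;
//   memset(&copy, 0, sizeof(copy));
//   if (Util::deserialize(buf, &copy)) { /* use copy */ }
//   avsubtitle_free(&copy);        // frees the av_malloc'ed rects/strings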
#include <c10/util/Logging.h>
#include <dirent.h>
#include <gtest/gtest.h>
#include "util.h"
TEST(Util, TestSetFormatDimensions) {
// clang-format off
const size_t test_cases[][9] = {
// (userW, userH, srcW, srcH, minDimension, maxDimension, cropImage, destW, destH)
{0, 0, 172, 128, 0, 0, 0, 172, 128}, // #1
{86, 0, 172, 128, 0, 0, 0, 86, 64}, // #2
{64, 0, 128, 172, 0, 0, 0, 64, 86}, // #2
{0, 32, 172, 128, 0, 0, 0, 43, 32}, // #3
{32, 0, 128, 172, 0, 0, 0, 32, 43}, // #3
{60, 50, 172, 128, 0, 0, 0, 60, 50}, // #4
{50, 60, 128, 172, 0, 0, 0, 50, 60}, // #4
{86, 40, 172, 128, 0, 0, 1, 86, 64}, // #5
{86, 92, 172, 128, 0, 0, 1, 124, 92}, // #5
{0, 0, 172, 128, 256, 0, 0, 344, 256}, // #6
{0, 0, 128, 172, 256, 0, 0, 256, 344}, // #6
{0, 0, 128, 172, 0, 344, 0, 256, 344}, // #7
{0, 0, 172, 128, 0, 344, 0, 344, 256}, // #7
{0, 0, 172, 128, 100, 344, 0, 344, 100},// #8
{0, 0, 128, 172, 100, 344, 0, 100, 344} // #8
};
// clang-format on
for (const auto& tc : test_cases) {
size_t destW = 0;
size_t destH = 0;
ffmpeg::Util::setFormatDimensions(destW, destH, tc[0], tc[1], tc[2], tc[3], tc[4], tc[5], tc[6]);
CHECK(destW == tc[7]);
CHECK(destH == tc[8]);
}
}
#include "video_sampler.h"
#include <c10/util/Logging.h>
#include "util.h"
// www.ffmpeg.org/doxygen/0.5/swscale-example_8c-source.html
namespace ffmpeg {
namespace {
int preparePlanes(
const VideoFormat& fmt,
const uint8_t* buffer,
uint8_t** planes,
int* lineSize) {
int result;
if ((result = av_image_fill_arrays(
planes,
lineSize,
buffer,
(AVPixelFormat)fmt.format,
fmt.width,
fmt.height,
1)) < 0) {
LOG(ERROR) << "av_image_fill_arrays failed, err: "
<< Util::generateErrorDesc(result);
}
return result;
}
int transformImage(
SwsContext* context,
const uint8_t* const srcSlice[],
int srcStride[],
VideoFormat inFormat,
VideoFormat outFormat,
uint8_t* out,
uint8_t* planes[],
int lines[]) {
int result;
if ((result = preparePlanes(outFormat, out, planes, lines)) < 0) {
return result;
}
if ((result = sws_scale(
context, srcSlice, srcStride, 0, inFormat.height, planes, lines)) <
0) {
LOG(ERROR) << "sws_scale failed, err: " << Util::generateErrorDesc(result);
return result;
}
return 0;
}
} // namespace
VideoSampler::VideoSampler(int swsFlags, int64_t loggingUuid)
: swsFlags_(swsFlags), loggingUuid_(loggingUuid) {}
VideoSampler::~VideoSampler() {
cleanUp();
}
void VideoSampler::shutdown() {
cleanUp();
}
bool VideoSampler::init(const SamplerParameters& params) {
cleanUp();
if (params.out.video.cropImage != 0) {
if (!Util::validateVideoFormat(params.out.video)) {
LOG(ERROR) << "Invalid video format"
<< ", width: " << params.out.video.width
<< ", height: " << params.out.video.height
<< ", format: " << params.out.video.format
<< ", minDimension: " << params.out.video.minDimension
<< ", crop: " << params.out.video.cropImage;
return false;
}
scaleFormat_.format = params.out.video.format;
Util::setFormatDimensions(
scaleFormat_.width,
scaleFormat_.height,
params.out.video.width,
params.out.video.height,
params.in.video.width,
params.in.video.height,
0,
0,
1);
if (!(scaleFormat_ == params.out.video)) { // crop required
cropContext_ = sws_getContext(
params.out.video.width,
params.out.video.height,
(AVPixelFormat)params.out.video.format,
params.out.video.width,
params.out.video.height,
(AVPixelFormat)params.out.video.format,
swsFlags_,
nullptr,
nullptr,
nullptr);
if (!cropContext_) {
LOG(ERROR) << "sws_getContext failed for crop context";
return false;
}
const auto scaleImageSize = av_image_get_buffer_size(
(AVPixelFormat)scaleFormat_.format,
scaleFormat_.width,
scaleFormat_.height,
1);
scaleBuffer_.resize(scaleImageSize);
}
} else {
scaleFormat_ = params.out.video;
}
VLOG(1) << "Input format #" << loggingUuid_ << ", width "
<< params.in.video.width << ", height " << params.in.video.height
<< ", format " << params.in.video.format << ", minDimension "
<< params.in.video.minDimension << ", cropImage "
<< params.in.video.cropImage;
VLOG(1) << "Scale format #" << loggingUuid_ << ", width "
<< scaleFormat_.width << ", height " << scaleFormat_.height
<< ", format " << scaleFormat_.format << ", minDimension "
<< scaleFormat_.minDimension << ", cropImage "
<< scaleFormat_.cropImage;
VLOG(1) << "Crop format #" << loggingUuid_ << ", width "
<< params.out.video.width << ", height " << params.out.video.height
<< ", format " << params.out.video.format << ", minDimension "
<< params.out.video.minDimension << ", cropImage "
<< params.out.video.cropImage;
scaleContext_ = sws_getContext(
params.in.video.width,
params.in.video.height,
(AVPixelFormat)params.in.video.format,
scaleFormat_.width,
scaleFormat_.height,
(AVPixelFormat)scaleFormat_.format,
swsFlags_,
nullptr,
nullptr,
nullptr);
// set output format
params_ = params;
return scaleContext_ != nullptr;
}
int VideoSampler::sample(
const uint8_t* const srcSlice[],
int srcStride[],
ByteStorage* out) {
int result;
// scaled and cropped image
int outImageSize = av_image_get_buffer_size(
(AVPixelFormat)params_.out.video.format,
params_.out.video.width,
params_.out.video.height,
1);
out->ensure(outImageSize);
uint8_t* scalePlanes[4] = {nullptr};
int scaleLines[4] = {0};
// perform scale first
if ((result = transformImage(
scaleContext_,
srcSlice,
srcStride,
params_.in.video,
scaleFormat_,
// for crop use internal buffer
cropContext_ ? scaleBuffer_.data() : out->writableTail(),
scalePlanes,
scaleLines))) {
return result;
}
// is crop required?
if (cropContext_) {
uint8_t* cropPlanes[4] = {nullptr};
int cropLines[4] = {0};
if (params_.out.video.height < scaleFormat_.height) {
// Destination image is wider than the source image: cut top and bottom
for (size_t i = 0; i < 4 && scalePlanes[i] != nullptr; ++i) {
scalePlanes[i] += scaleLines[i] *
(scaleFormat_.height - params_.out.video.height) / 2;
}
} else {
// Source image is wider than the destination image: cut sides
for (size_t i = 0; i < 4 && scalePlanes[i] != nullptr; ++i) {
scalePlanes[i] += scaleLines[i] *
(scaleFormat_.width - params_.out.video.width) / 2 /
scaleFormat_.width;
}
}
// crop image
if ((result = transformImage(
cropContext_,
scalePlanes,
scaleLines,
params_.out.video,
params_.out.video,
out->writableTail(),
cropPlanes,
cropLines))) {
return result;
}
}
out->append(outImageSize);
return outImageSize;
}
int VideoSampler::sample(AVFrame* frame, ByteStorage* out) {
if (!frame) {
return 0; // no flush for videos
}
return sample(frame->data, frame->linesize, out);
}
int VideoSampler::sample(const ByteStorage* in, ByteStorage* out) {
if (!in) {
return 0; // no flush for videos
}
int result;
uint8_t* inPlanes[4] = {nullptr};
int inLineSize[4] = {0};
if ((result = preparePlanes(
params_.in.video, in->data(), inPlanes, inLineSize)) < 0) {
return result;
}
return sample(inPlanes, inLineSize, out);
}
void VideoSampler::cleanUp() {
if (scaleContext_) {
sws_freeContext(scaleContext_);
scaleContext_ = nullptr;
}
if (cropContext_) {
sws_freeContext(cropContext_);
cropContext_ = nullptr;
scaleBuffer_.clear();
}
}
} // namespace ffmpeg
#pragma once
#include "defs.h"
namespace ffmpeg {
/**
* Class transcodes video frames from one format into another.
*/
class VideoSampler : public MediaSampler {
public:
VideoSampler(int swsFlags = SWS_AREA, int64_t loggingUuid = 0);
~VideoSampler() override;
// MediaSampler overrides
bool init(const SamplerParameters& params) override;
int sample(const ByteStorage* in, ByteStorage* out) override;
void shutdown() override;
// returns the number of processed/scaled bytes
int sample(AVFrame* frame, ByteStorage* out);
int getImageBytes() const;
private:
// close resources
void cleanUp();
// helper functions for rescaling, cropping, etc.
int sample(
const uint8_t* const srcSlice[],
int srcStride[],
ByteStorage* out);
private:
VideoFormat scaleFormat_;
SwsContext* scaleContext_{nullptr};
SwsContext* cropContext_{nullptr};
int swsFlags_{SWS_AREA};
std::vector<uint8_t> scaleBuffer_;
int64_t loggingUuid_{0};
};
} // namespace ffmpeg
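// A minimal usage sketch (field names as used by video_stream.cpp; the
// concrete dimensions and pixel formats here are hypothetical):
//
//   SamplerParameters params;
//   params.type = TYPE_VIDEO;
//   params.in.video.width = 1920;   // source geometry and pixel format
//   params.in.video.height = 1080;
//   params.in.video.format = AV_PIX_FMT_YUV420P;
//   params.out.video.width = 224;   // requested output geometry
//   params.out.video.height = 224;
//   params.out.video.format = AV_PIX_FMT_RGB24;
//   params.out.video.cropImage = 1; // scale first, then center-crop
//   VideoSampler sampler;
//   if (sampler.init(params)) {
//     sampler.sample(frame, &storage); // AVFrame* in, ByteStorage* out
//   }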
#include "video_stream.h"
#include <c10/util/Logging.h>
#include "util.h"
namespace ffmpeg {
namespace {
bool operator==(const VideoFormat& x, const AVFrame& y) {
return x.width == y.width && x.height == y.height && x.format == y.format;
}
bool operator==(const VideoFormat& x, const AVCodecContext& y) {
return x.width == y.width && x.height == y.height && x.format == y.pix_fmt;
}
VideoFormat& toVideoFormat(VideoFormat& x, const AVFrame& y) {
x.width = y.width;
x.height = y.height;
x.format = y.format;
return x;
}
VideoFormat& toVideoFormat(VideoFormat& x, const AVCodecContext& y) {
x.width = y.width;
x.height = y.height;
x.format = y.pix_fmt;
return x;
}
} // namespace
VideoStream::VideoStream(
AVFormatContext* inputCtx,
int index,
bool convertPtsToWallTime,
const VideoFormat& format,
int64_t loggingUuid)
: Stream(
inputCtx,
MediaFormat::makeMediaFormat(format, index),
convertPtsToWallTime,
loggingUuid) {}
VideoStream::~VideoStream() {
if (sampler_) {
sampler_->shutdown();
sampler_.reset();
}
}
int VideoStream::initFormat() {
// set output format
if (!Util::validateVideoFormat(format_.format.video)) {
LOG(ERROR) << "Invalid video format"
<< ", width: " << format_.format.video.width
<< ", height: " << format_.format.video.height
<< ", format: " << format_.format.video.format
<< ", minDimension: " << format_.format.video.minDimension
<< ", crop: " << format_.format.video.cropImage;
return -1;
}
// keep aspect ratio
Util::setFormatDimensions(
format_.format.video.width,
format_.format.video.height,
format_.format.video.width,
format_.format.video.height,
codecCtx_->width,
codecCtx_->height,
format_.format.video.minDimension,
format_.format.video.maxDimension,
0);
if (format_.format.video.format == AV_PIX_FMT_NONE) {
format_.format.video.format = codecCtx_->pix_fmt;
}
return format_.format.video.width != 0 && format_.format.video.height != 0 &&
format_.format.video.format != AV_PIX_FMT_NONE
? 0
: -1;
}
int VideoStream::copyFrameBytes(ByteStorage* out, bool flush) {
if (!sampler_) {
sampler_ = std::make_unique<VideoSampler>(SWS_AREA, loggingUuid_);
}
// check if input format gets changed
if (flush ? !(sampler_->getInputFormat().video == *codecCtx_)
: !(sampler_->getInputFormat().video == *frame_)) {
// - reinit sampler
SamplerParameters params;
params.type = format_.type;
params.out = format_.format;
params.in = FormatUnion(0);
flush ? toVideoFormat(params.in.video, *codecCtx_)
: toVideoFormat(params.in.video, *frame_);
if (!sampler_->init(params)) {
return -1;
}
VLOG(1) << "Set input video sampler format"
<< ", width: " << params.in.video.width
<< ", height: " << params.in.video.height
<< ", format: " << params.in.video.format
<< " : output video sampler format"
<< ", width: " << format_.format.video.width
<< ", height: " << format_.format.video.height
<< ", format: " << format_.format.video.format
<< ", minDimension: " << format_.format.video.minDimension
<< ", crop: " << format_.format.video.cropImage;
}
return sampler_->sample(flush ? nullptr : frame_, out);
}
void VideoStream::setHeader(DecoderHeader* header, bool flush) {
Stream::setHeader(header, flush);
if (!flush) { // no frames for video flush
header->keyFrame = frame_->key_frame;
header->fps = av_q2d(av_guess_frame_rate(
inputCtx_, inputCtx_->streams[format_.stream], nullptr));
}
}
} // namespace ffmpeg
#pragma once
#include "stream.h"
#include "video_sampler.h"
namespace ffmpeg {
/**
* Class uses FFMPEG library to decode one video stream.
*/
class VideoStream : public Stream {
public:
VideoStream(
AVFormatContext* inputCtx,
int index,
bool convertPtsToWallTime,
const VideoFormat& format,
int64_t loggingUuid);
~VideoStream() override;
private:
int initFormat() override;
int copyFrameBytes(ByteStorage* out, bool flush) override;
void setHeader(DecoderHeader* header, bool flush) override;
private:
std::unique_ptr<VideoSampler> sampler_;
};
} // namespace ffmpeg
#include "image.h"
#include <ATen/ATen.h>
#include <Python.h>
// If we are in a Windows environment, we need to define
// initialization functions for the image extension
#ifdef _WIN32
PyMODINIT_FUNC PyInit_image(void) {
// No need to do anything.
return NULL;
}
#endif
static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::encode_png", &encodePNG)
.op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
#pragma once
// Comment
#include <torch/script.h>
#include <torch/torch.h>
#include "read_image_cpu.h"
#include "read_write_file_cpu.h"
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "writejpeg_cpu.h"
#include "writepng_cpu.h"
#include "jpegcommon.h"
#include <string>
#if JPEG_FOUND
void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
/* Always display the message. */
/* We could postpone this until after returning, if we chose. */
// (*cinfo->err->output_message)(cinfo);
/* Create the message */
(*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg);
/* Return control to the setjmp point */
longjmp(myerr->setjmp_buffer, 1);
}
#endif
#pragma once
// clang-format off
#include <cstdio>
#include <cstddef>
// clang-format on
#if JPEG_FOUND
#include <jpeglib.h>
#include <setjmp.h>
#include <string>
static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */
jmp_buf setjmp_buffer; /* for return to caller */
};
typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr;
void torch_jpeg_error_exit(j_common_ptr cinfo);
#endif
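// Usage sketch of the error manager above (it mirrors decodeJPEG below):
//
//   struct jpeg_decompress_struct cinfo;
//   struct torch_jpeg_error_mgr jerr;
//   cinfo.err = jpeg_std_error(&jerr.pub);
//   jerr.pub.error_exit = torch_jpeg_error_exit;
//   if (setjmp(jerr.setjmp_buffer)) {
//     // control lands here when libjpeg reports a fatal error
//     jpeg_destroy_decompress(&cinfo);
//     // surface jerr.jpegLastErrorMsg to the caller
//   }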
#include "read_image_cpu.h"
#include <cstring>
torch::Tensor decode_image(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
auto datap = data.data_ptr<uint8_t>();
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg and png ",
"are currently supported.");
}
}
#pragma once
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);
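// Hedged usage sketch combining the ops declared here and in
// read_write_file_cpu.h (the path is hypothetical):
//
//   torch::Tensor bytes = read_file("/path/to/image.jpg"); // 1-D uint8
//   torch::Tensor chw = decode_image(bytes); // C x H x W uint8 tensor,
//                                            // dispatched on the JPEG/PNG
//                                            // file signature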
#include "read_write_file_cpu.h"
// According to
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
// we should use _stat64 for 64-bit file size on Windows.
#ifdef _WIN32
#define VISION_STAT _stat64
#else
#define VISION_STAT stat
#endif
torch::Tensor read_file(std::string filename) {
struct VISION_STAT stat_buf;
int rc = VISION_STAT(filename.c_str(), &stat_buf);
// errno is a variable defined in errno.h
TORCH_CHECK(
rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'");
int64_t size = stat_buf.st_size;
TORCH_CHECK(size > 0, "Expected a non empty file");
#ifdef _WIN32
auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8)
.clone();
#else
auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8);
#endif
return data;
}
void write_file(std::string filename, torch::Tensor& data) {
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor");
auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
FILE* outfile = fopen(fileCStr, "wb");
TORCH_CHECK(outfile != nullptr, "Error opening output file");
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}
#pragma once
#include <errno.h>
#include <sys/stat.h>
#include <torch/torch.h>
C10_EXPORT torch::Tensor read_file(std::string filename);
C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
#include "readjpeg_cpu.h"
#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>
#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
#else
#include <jpeglib.h>
#include "jpegcommon.h"
struct torch_jpeg_mgr {
struct jpeg_source_mgr pub;
const JOCTET* data;
size_t len;
};
static void torch_jpeg_init_source(j_decompress_ptr cinfo) {}
static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
// No more data. Probably an incomplete image; raise an exception.
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
longjmp(myerr->setjmp_buffer, 1);
src->pub.next_input_byte = EOI_BUFFER;
src->pub.bytes_in_buffer = 1;
return TRUE;
}
static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) {
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
if (src->pub.bytes_in_buffer < num_bytes) {
// Skipping over all of the remaining data; output EOI.
src->pub.next_input_byte = EOI_BUFFER;
src->pub.bytes_in_buffer = 1;
} else {
// Skipping over only some of the remaining data.
src->pub.next_input_byte += num_bytes;
src->pub.bytes_in_buffer -= num_bytes;
}
}
static void torch_jpeg_term_source(j_decompress_ptr cinfo) {}
static void torch_jpeg_set_source_mgr(
j_decompress_ptr cinfo,
const unsigned char* data,
size_t len) {
torch_jpeg_mgr* src;
if (cinfo->src == 0) { // if this is the first time, allocate memory
cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)(
(j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr));
}
src = (torch_jpeg_mgr*)cinfo->src;
src->pub.init_source = torch_jpeg_init_source;
src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer;
src->pub.skip_input_data = torch_jpeg_skip_input_data;
src->pub.resync_to_restart = jpeg_resync_to_restart; // default
src->pub.term_source = torch_jpeg_term_source;
// fill the buffers
src->data = (const JOCTET*)data;
src->len = len;
src->pub.bytes_in_buffer = len;
src->pub.next_input_byte = src->data;
}
torch::Tensor decodeJPEG(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
auto datap = data.data_ptr<uint8_t>();
// Setup decompression structure
cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;
/* Establish the setjmp return context for my_error_exit to use. */
if (setjmp(jerr.setjmp_buffer)) {
/* If we get here, the JPEG code has signaled an error.
* We need to clean up the JPEG object.
*/
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, jerr.jpegLastErrorMsg);
}
jpeg_create_decompress(&cinfo);
torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());
// read info from header.
jpeg_read_header(&cinfo, TRUE);
jpeg_start_decompress(&cinfo);
int height = cinfo.output_height;
int width = cinfo.output_width;
int components = cinfo.output_components;
auto stride = width * components;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), int64_t(components)}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
jpeg_read_scanlines(&cinfo, &ptr, 1);
ptr += stride;
}
jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
}
#endif // JPEG_FOUND