"examples/dreambooth/train_dreambooth_inpaint.py" did not exist on "906e4105d7883384d5982eea160c4d7c1eb9327a"
Commit 0fc002df authored by huchen's avatar huchen
Browse files

init the dlexamples new

parent 0e04b692
#pragma once
#include <torch/torch.h>
C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data);
#include "readpng_cpu.h"
#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>
#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
torch::Tensor decodePNG(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")
auto info_ptr = png_create_info_struct(png_ptr);
if (!info_ptr) {
png_destroy_read_struct(&png_ptr, nullptr, nullptr);
// This check looks redundant with the surrounding if, but it is performed
// after freeing the read struct so that raising the error does not leak memory.
TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
}
auto datap = data.accessor<unsigned char, 1>().data();
if (setjmp(png_jmpbuf(png_ptr)) != 0) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Internal error.");
}
auto is_png = !png_sig_cmp(datap, 0, 8);
TORCH_CHECK(is_png, "Content is not png!")
struct Reader {
png_const_bytep ptr;
} reader;
reader.ptr = png_const_bytep(datap) + 8;
auto read_callback =
[](png_structp png_ptr, png_bytep output, png_size_t bytes) {
auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
std::copy(reader->ptr, reader->ptr + bytes, output);
reader->ptr += bytes;
};
png_set_sig_bytes(png_ptr, 8);
png_set_read_fn(png_ptr, &reader, read_callback);
png_read_info(png_ptr, info_ptr);
png_uint_32 width, height;
int bit_depth, color_type;
auto retval = png_get_IHDR(
png_ptr,
info_ptr,
&width,
&height,
&bit_depth,
&color_type,
nullptr,
nullptr,
nullptr);
if (retval != 1) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(retval == 1, "Could not read image metadata from content.")
}
if (color_type != PNG_COLOR_TYPE_RGB) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")
}
auto tensor =
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (decltype(height) i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
}
#endif // PNG_FOUND
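// A minimal usage sketch, not part of the original file: decodePNG() above
// expects a non-empty 1-D uint8 tensor holding raw PNG bytes and returns a
// CHW uint8 tensor (RGB only). The helper name and the path argument are
// hypothetical placeholders.
#include <fstream>
#include <vector>
torch::Tensor decodePNGFromFile(const std::string& path) {
  // Read the whole file into memory.
  std::ifstream file(path, std::ios::binary);
  std::vector<char> bytes(
      (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  // Copy the bytes into a tensor owned by torch (from_blob does not take ownership).
  auto data =
      torch::from_blob(bytes.data(), {int64_t(bytes.size())}, torch::kU8).clone();
  return decodePNG(data);
}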
#pragma once
#include <torch/torch.h>
#include <string>
C10_EXPORT torch::Tensor decodePNG(const torch::Tensor& data);
#include "writejpeg_cpu.h"
#include <setjmp.h>
#include <string>
#if !JPEG_FOUND
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
TORCH_CHECK(
false, "encodeJPEG: torchvision not compiled with libjpeg support");
}
#else
#include <jpeglib.h>
#include "jpegcommon.h"
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling
struct jpeg_compress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
// Define buffer to write JPEG information to and its size
unsigned long jpegSize = 0;
uint8_t* jpegBuf = NULL;
cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;
/* Establish the setjmp return context for my_error_exit to use. */
if (setjmp(jerr.setjmp_buffer)) {
/* If we get here, the JPEG code has signaled an error.
* We need to clean up the JPEG object and the buffer.
*/
jpeg_destroy_compress(&cinfo);
if (jpegBuf != NULL) {
free(jpegBuf);
}
TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
}
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");
// Get image info
int channels = data.size(0);
int height = data.size(1);
int width = data.size(2);
auto input = data.permute({1, 2, 0}).contiguous();
TORCH_CHECK(
channels == 1 || channels == 3,
"The number of channels should be 1 or 3, got: ",
channels);
// Initialize JPEG structure
jpeg_create_compress(&cinfo);
// Set output image information
cinfo.image_width = width;
cinfo.image_height = height;
cinfo.input_components = channels;
cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;
jpeg_set_defaults(&cinfo);
jpeg_set_quality(&cinfo, quality, TRUE);
// Save JPEG output to a buffer
jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);
// Start JPEG compression
jpeg_start_compress(&cinfo, TRUE);
auto stride = width * channels;
auto ptr = input.data_ptr<uint8_t>();
// Encode JPEG file
while (cinfo.next_scanline < cinfo.image_height) {
jpeg_write_scanlines(&cinfo, &ptr, 1);
ptr += stride;
}
jpeg_finish_compress(&cinfo);
jpeg_destroy_compress(&cinfo);
torch::TensorOptions options = torch::TensorOptions{torch::kU8};
auto outTensor = torch::empty({(long)jpegSize}, options);
// Copy memory from jpeg buffer, since torch cannot get ownership of it via
// `from_blob`
auto outPtr = outTensor.data_ptr<uint8_t>();
std::memcpy(outPtr, jpegBuf, sizeof(uint8_t) * outTensor.numel());
free(jpegBuf);
return outTensor;
}
#endif
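// A small hedged usage sketch, not part of the original file: encodeJPEG()
// above takes a 3-D uint8 CPU tensor (CHW, 1 or 3 channels) and returns the
// encoded bytes as a 1-D uint8 tensor. The function name, quality value and
// output file name are illustrative placeholders.
#include <fstream>
void writeJPEGExample(const torch::Tensor& chw_uint8_image) {
  torch::Tensor bytes = encodeJPEG(chw_uint8_image, /*quality=*/75);
  std::ofstream out("example.jpg", std::ios::binary);
  out.write(
      reinterpret_cast<const char*>(bytes.data_ptr<uint8_t>()), bytes.numel());
}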
#pragma once
#include <torch/torch.h>
C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);
#include "writejpeg_cpu.h"
#include <setjmp.h>
#include <string>
#if !PNG_FOUND
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(false, "encodePNG: torchvision not compiled with libpng support");
}
#else
#include <png.h>
struct torch_mem_encode {
char* buffer;
size_t size;
};
struct torch_png_error_mgr {
const char* pngLastErrorMsg; /* error messages */
jmp_buf setjmp_buffer; /* for return to caller */
};
typedef torch_png_error_mgr* torch_png_error_mgr_ptr;
void torch_png_warn(png_structp png_ptr, png_const_charp warn_msg) {
/* Display warning to user */
TORCH_WARN_ONCE(warn_msg);
}
void torch_png_error(png_structp png_ptr, png_const_charp error_msg) {
/* png_ptr->err really points to a torch_png_error_mgr struct, so coerce
* pointer */
auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr);
/* Replace the error message on the error structure */
error_ptr->pngLastErrorMsg = error_msg;
/* Return control to the setjmp point */
longjmp(error_ptr->setjmp_buffer, 1);
}
void torch_png_write_data(
png_structp png_ptr,
png_bytep data,
png_size_t length) {
struct torch_mem_encode* p =
(struct torch_mem_encode*)png_get_io_ptr(png_ptr);
size_t nsize = p->size + length;
/* allocate or grow buffer */
if (p->buffer)
p->buffer = (char*)realloc(p->buffer, nsize);
else
p->buffer = (char*)malloc(nsize);
if (!p->buffer)
png_error(png_ptr, "Write Error");
/* copy new bytes to end of buffer */
memcpy(p->buffer + p->size, data, length);
p->size += length;
}
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
// Define compression structures and error handling
// Initialize to nullptr so the setjmp cleanup below is well defined even if
// the error handler fires before these structures are created.
png_structp png_write = nullptr;
png_infop info_ptr = nullptr;
struct torch_png_error_mgr err_ptr;
// Define output buffer
struct torch_mem_encode buf_info;
buf_info.buffer = NULL;
buf_info.size = 0;
/* Establish the setjmp return context for my_error_exit to use. */
if (setjmp(err_ptr.setjmp_buffer)) {
/* If we get here, the PNG code has signaled an error.
* We need to clean up the PNG object and the buffer.
*/
if (info_ptr != NULL) {
png_destroy_info_struct(png_write, &info_ptr);
}
if (png_write != NULL) {
png_destroy_write_struct(&png_write, NULL);
}
if (buf_info.buffer != NULL) {
free(buf_info.buffer);
}
TORCH_CHECK(false, err_ptr.pngLastErrorMsg);
}
// Check that the compression level is between 0 and 9
TORCH_CHECK(
compression_level >= 0 && compression_level <= 9,
"Compression level should be between 0 and 9");
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");
// Get image info
int channels = data.size(0);
int height = data.size(1);
int width = data.size(2);
auto input = data.permute({1, 2, 0}).contiguous();
TORCH_CHECK(
channels == 1 || channels == 3,
"The number of channels should be 1 or 3, got: ",
channels);
// Initialize PNG structures
png_write = png_create_write_struct(
PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL);
info_ptr = png_create_info_struct(png_write);
// Define custom buffer output
png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL);
// Set output image information
auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB;
png_set_IHDR(
png_write,
info_ptr,
width,
height,
8,
color_type,
PNG_INTERLACE_NONE,
PNG_COMPRESSION_TYPE_DEFAULT,
PNG_FILTER_TYPE_DEFAULT);
// Set image compression level
png_set_compression_level(png_write, compression_level);
// Write file header
png_write_info(png_write, info_ptr);
auto stride = width * channels;
auto ptr = input.data_ptr<uint8_t>();
// Encode PNG file
for (size_t y = 0; y < height; ++y) {
png_write_row(png_write, ptr);
ptr += stride;
}
// Write EOF
png_write_end(png_write, info_ptr);
// Destroy structures
png_destroy_write_struct(&png_write, &info_ptr);
torch::TensorOptions options = torch::TensorOptions{torch::kU8};
auto outTensor = torch::empty({(long)buf_info.size}, options);
// Copy memory from png buffer, since torch cannot get ownership of it via
// `from_blob`
auto outPtr = outTensor.data_ptr<uint8_t>();
std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
free(buf_info.buffer);
return outTensor;
}
#endif
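// A hedged usage sketch, not part of the original file: encode a synthetic
// 3x64x64 RGB uint8 image with compression level 6 (the valid range is 0-9
// per the check above) and write the bytes to a placeholder file name.
#include <fstream>
void writePNGExample() {
  auto img = torch::randint(0, 256, {3, 64, 64}, torch::kU8);
  torch::Tensor bytes = encodePNG(img, /*compression_level=*/6);
  std::ofstream out("example.png", std::ios::binary);
  out.write(
      reinterpret_cast<const char*>(bytes.data_ptr<uint8_t>()), bytes.numel());
}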
#pragma once
#include <torch/torch.h>
C10_EXPORT torch::Tensor encodePNG(
const torch::Tensor& data,
int64_t compression_level);
#include "vision_cpu.h"
template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
if (dets.numel() == 0)
return at::empty({0}, dets.options().dtype(at::kLong));
auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
auto x1 = x1_t.data_ptr<scalar_t>();
auto y1 = y1_t.data_ptr<scalar_t>();
auto x2 = x2_t.data_ptr<scalar_t>();
auto y2 = y2_t.data_ptr<scalar_t>();
auto areas = areas_t.data_ptr<scalar_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1)
continue;
keep[num_to_keep++] = i;
auto ix1 = x1[i];
auto iy1 = y1[i];
auto ix2 = x2[i];
auto iy2 = y2[i];
auto iarea = areas[i];
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1)
continue;
auto xx1 = std::max(ix1, x1[j]);
auto yy1 = std::max(iy1, y1[j]);
auto xx2 = std::min(ix2, x2[j]);
auto yy2 = std::min(iy2, y2[j]);
auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr > iou_threshold)
suppressed[j] = 1;
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0));
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
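// A brief illustration, not part of the original file: two heavily overlapping
// boxes plus one disjoint box. The first two have IoU ~0.68, so with
// iou_threshold = 0.5 the lower-scoring one is suppressed and the returned
// indices are {0, 2}. Scores are already in descending order here.
void nmsCPUExample() {
  auto dets = torch::tensor(
                  {0.0f, 0.0f, 10.0f, 10.0f,
                   1.0f, 1.0f, 11.0f, 11.0f,
                   20.0f, 20.0f, 30.0f, 30.0f})
                  .reshape({3, 4});
  auto scores = torch::tensor({0.9f, 0.8f, 0.7f});
  at::Tensor keep = nms_cpu(dets, scores, /*iou_threshold=*/0.5);
}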
#include "Video.h"
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
using namespace std;
using namespace ffmpeg;
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
// #ifdef _WIN32
// #if PY_MAJOR_VERSION < 3
// PyMODINIT_FUNC init_video_reader(void) {
// // No need to do anything.
// return NULL;
// }
// #else
// PyMODINIT_FUNC PyInit_video_reader(void) {
// // No need to do anything.
// return NULL;
// }
// #endif
// #endif
const size_t decoderTimeoutMs = 600000;
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
// copies the payload into the frame tensor; returns the element size in bytes
// (not the total number of bytes written)
template <typename T>
size_t fillTensorList(DecoderOutputMessage& msgs, torch::Tensor& frame) {
const auto& msg = msgs;
T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
if (frameData) {
auto sizeInBytes = msg.payload->length();
memcpy(frameData, msg.payload->data(), sizeInBytes);
}
return sizeof(T);
}
size_t fillVideoTensor(DecoderOutputMessage& msgs, torch::Tensor& videoFrame) {
return fillTensorList<uint8_t>(msgs, videoFrame);
}
size_t fillAudioTensor(DecoderOutputMessage& msgs, torch::Tensor& audioFrame) {
return fillTensorList<float>(msgs, audioFrame);
}
std::array<std::pair<std::string, ffmpeg::MediaType>, 4>::const_iterator
_parse_type(const std::string& stream_string) {
static const std::array<std::pair<std::string, MediaType>, 4> types = {{
{"video", TYPE_VIDEO},
{"audio", TYPE_AUDIO},
{"subtitle", TYPE_SUBTITLE},
{"cc", TYPE_CC},
}};
auto device = std::find_if(
types.begin(),
types.end(),
[stream_string](const std::pair<std::string, MediaType>& p) {
return p.first == stream_string;
});
if (device != types.end()) {
return device;
}
TORCH_CHECK(
false, "Expected one of [audio, video, subtitle, cc] ", stream_string);
}
std::string parse_type_to_string(const std::string& stream_string) {
auto device = _parse_type(stream_string);
return device->first;
}
MediaType parse_type_to_mt(const std::string& stream_string) {
auto device = _parse_type(stream_string);
return device->second;
}
std::tuple<std::string, long> _parseStream(const std::string& streamString) {
TORCH_CHECK(!streamString.empty(), "Stream string must not be empty");
static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
std::smatch match;
TORCH_CHECK(
std::regex_match(streamString, match, regex),
"Invalid stream string: '",
streamString,
"'");
std::string type_ = "video";
type_ = parse_type_to_string(match[1].str());
long index_ = -1;
if (match[2].matched) {
try {
index_ = c10::stoi(match[2].str());
} catch (const std::exception&) {
TORCH_CHECK(
false,
"Could not parse device index '",
match[2].str(),
"' in device string '",
streamString,
"'");
}
}
return std::make_tuple(type_, index_);
}
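// Examples of stream strings accepted by the regex above:
//   "video"   -> ("video", -1)   default stream of that type
//   "audio:1" -> ("audio", 1)    explicit stream index
// Types outside [audio, video, subtitle, cc] fail the TORCH_CHECK in
// _parse_type.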
void Video::_getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
std::string stream,
long stream_id = -1,
bool all_streams = false,
double seekFrameMarginUs = 10) {
int64_t videoStartUs = int64_t(videoStartS * 1e6);
params.timeoutMs = decoderTimeoutMs;
params.startOffset = videoStartUs;
params.seekAccuracy = seekFrameMarginUs;
params.headerOnly = false;
params.preventStaleness = false; // not sure what this is about
if (all_streams == true) {
MediaFormat format;
format.stream = -2;
format.type = TYPE_AUDIO;
params.formats.insert(format);
format.type = TYPE_VIDEO;
format.stream = -2;
format.format.video.width = 0;
format.format.video.height = 0;
format.format.video.cropImage = 0;
format.format.video.format = defaultVideoPixelFormat;
params.formats.insert(format);
format.type = TYPE_SUBTITLE;
format.stream = -2;
params.formats.insert(format);
format.type = TYPE_CC;
format.stream = -2;
params.formats.insert(format);
} else {
// parse stream type
MediaType stream_type = parse_type_to_mt(stream);
// TODO: reset params.formats
std::set<MediaFormat> formats;
params.formats = formats;
// Define new format
MediaFormat format;
format.type = stream_type;
format.stream = stream_id;
if (stream_type == TYPE_VIDEO) {
format.format.video.width = 0;
format.format.video.height = 0;
format.format.video.cropImage = 0;
format.format.video.format = defaultVideoPixelFormat;
}
params.formats.insert(format);
}
} // _get decoder params
Video::Video(std::string videoPath, std::string stream) {
// parse stream information
current_stream = _parseStream(stream);
// note that in the initial call we want to get all streams
Video::_getDecoderParams(
0, // video start
0, // headerOnly
get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2
true // read all streams
);
std::string logMessage, logType;
// TODO: add read from memory option
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
// locals
std::vector<double> audioFPS, videoFPS, ccFPS, subsFPS;
std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
std::vector<double> audioTB, videoTB, ccTB, subsTB;
c10::Dict<std::string, std::vector<double>> audioMetadata;
c10::Dict<std::string, std::vector<double>> videoMetadata;
// callback and metadata defined in struct
succeeded = decoder.init(params, std::move(callback), &metadata);
if (succeeded) {
for (const auto& header : metadata) {
double fps = double(header.fps);
double duration = double(header.duration) * 1e-6; // * timeBase;
if (header.format.type == TYPE_VIDEO) {
videoFPS.push_back(fps);
videoDuration.push_back(duration);
} else if (header.format.type == TYPE_AUDIO) {
audioFPS.push_back(fps);
audioDuration.push_back(duration);
} else if (header.format.type == TYPE_CC) {
ccFPS.push_back(fps);
ccDuration.push_back(duration);
} else if (header.format.type == TYPE_SUBTITLE) {
subsFPS.push_back(fps);
subsDuration.push_back(duration);
};
}
}
audioMetadata.insert("duration", audioDuration);
audioMetadata.insert("framerate", audioFPS);
videoMetadata.insert("duration", videoDuration);
videoMetadata.insert("fps", videoFPS);
streamsMetadata.insert("video", videoMetadata);
streamsMetadata.insert("audio", audioMetadata);
succeeded = Video::setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
} // video
bool Video::setCurrentStream(std::string stream = "video") {
if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
current_stream = _parseStream(stream);
}
double ts = 0;
if (seekTS > 0) {
ts = seekTS;
}
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
// callback and metadata defined in Video.h
return (decoder.init(params, std::move(callback), &metadata));
}
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
return current_stream;
}
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
getStreamMetadata() const {
return streamsMetadata;
}
void Video::Seek(double ts) {
// initialize the class variables used for seeking and return
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
// callback and metadata defined in Video.h
succeeded = decoder.init(params, std::move(callback), &metadata);
LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}
std::tuple<torch::Tensor, double> Video::Next() {
// if decoding fails we simply return an empty tensor (note: should we
// raise an exception instead?)
double frame_pts_s = 0;
torch::Tensor outFrame = torch::zeros({0}, torch::kByte);
// decode single frame
DecoderOutputMessage out;
int64_t res = decoder.decode(&out, decoderTimeoutMs);
// if successful
if (res == 0) {
frame_pts_s = double(double(out.header.pts) * 1e-6);
auto header = out.header;
const auto& format = header.format;
// initialize the output variables based on type
if (format.type == TYPE_VIDEO) {
// note: this can potentially be optimized
// by having the global tensor that we fill at decode time
// (would avoid allocations)
int outHeight = format.format.video.height;
int outWidth = format.format.video.width;
int numChannels = 3;
outFrame = torch::zeros({outHeight, outWidth, numChannels}, torch::kByte);
auto numberWrittenBytes = fillVideoTensor(out, outFrame);
outFrame = outFrame.permute({2, 0, 1});
} else if (format.type == TYPE_AUDIO) {
int outAudioChannels = format.format.audio.channels;
int bytesPerSample = av_get_bytes_per_sample(
static_cast<AVSampleFormat>(format.format.audio.format));
int frameSizeTotal = out.payload->length();
CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
int numAudioSamples =
frameSizeTotal / (outAudioChannels * bytesPerSample);
outFrame =
torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
auto numberWrittenBytes = fillAudioTensor(out, outFrame);
}
// currently not supporting other formats (will do soon)
out.payload.reset();
} else if (res == 61) {
LOG(INFO) << "Decoder ran out of frames (error 61)\n";
} else {
LOG(ERROR) << "Decoder failed with ERROR_CODE " << res;
}
std::tuple<torch::Tensor, double> result = {outFrame, frame_pts_s};
return result;
}
#pragma once
#include <map>
#include <regex>
#include <string>
#include <vector>
#include <ATen/ATen.h>
#include <Python.h>
#include <c10/util/Logging.h>
#include <torch/script.h>
#include <exception>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
using namespace ffmpeg;
struct Video : torch::CustomClassHolder {
std::tuple<std::string, long> current_stream; // stream type, id
// global video metadata
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
streamsMetadata;
public:
Video(std::string videoPath, std::string stream);
std::tuple<std::string, int64_t> getCurrentStream() const;
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
getStreamMetadata() const;
void Seek(double ts);
bool setCurrentStream(std::string stream);
std::tuple<torch::Tensor, double> Next();
private:
bool video_any_frame = false; // add this to input parameters?
bool succeeded = false; // decoder init flag
// seekTS and doSeek act as a flag - if it's not set, the next function simply
// returns the next frame. If it's set, we look at the global seek
// time in combination with any_frame settings
double seekTS = -1;
bool doSeek = false;
void _getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
std::string stream,
long stream_id,
bool all_streams,
double seekFrameMarginUs); // this needs to be improved
std::map<std::string, std::vector<double>> streamTimeBase; // not used
DecoderInCallback callback = nullptr;
std::vector<DecoderMetadata> metadata;
protected:
SyncDecoder decoder;
DecoderParameters params;
}; // struct Video
#ifndef REGISTER_H
#define REGISTER_H
#include "Video.h"
namespace {
static auto registerVideo =
torch::class_<Video>("torchvision", "Video")
.def(torch::init<std::string, std::string>())
.def("get_current_stream", &Video::getCurrentStream)
.def("set_current_stream", &Video::setCurrentStream)
.def("get_metadata", &Video::getStreamMetadata)
.def("seek", &Video::Seek)
.def("next", &Video::Next);
} // namespace
#endif
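// A hedged usage sketch, not part of the original header: the class registered
// above can also be used directly from C++. The function name and the
// "example.mp4" path are placeholders; the stream string follows the
// _parseStream grammar ("video", "audio:1", ...).
#include <tuple>
inline void videoExample() {
  Video video("example.mp4", "video");
  video.Seek(1.0); // seek to t = 1.0 s
  torch::Tensor frame;
  double pts;
  std::tie(frame, pts) = video.Next(); // CHW uint8 frame and its pts in seconds
}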
#include "VideoReader.h"
#include <ATen/ATen.h>
#include <Python.h>
#include <c10/util/Logging.h>
#include <exception>
#include "memory_buffer.h"
#include "sync_decoder.h"
using namespace std;
using namespace ffmpeg;
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
#ifdef _WIN32
#if PY_MAJOR_VERSION < 3
PyMODINIT_FUNC init_video_reader(void) {
// No need to do anything.
return NULL;
}
#else
PyMODINIT_FUNC PyInit_video_reader(void) {
// No need to do anything.
return NULL;
}
#endif
#endif
namespace video_reader {
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
const size_t decoderTimeoutMs = 600000;
// A jitter can be added to the end of the range to avoid conversion/rounding
// error, small value 100us won't be enough to select the next frame, but enough
// to compensate rounding error due to the multiple conversions.
const size_t timeBaseJitterUs = 100;
DecoderParameters getDecoderParams(
int64_t videoStartUs,
int64_t videoEndUs,
double seekFrameMarginUs,
int64_t getPtsOnly,
int64_t readVideoStream,
int videoWidth,
int videoHeight,
int videoMinDimension,
int videoMaxDimension,
int64_t readAudioStream,
int audioSamples,
int audioChannels) {
DecoderParameters params;
params.headerOnly = getPtsOnly != 0;
params.seekAccuracy = seekFrameMarginUs;
params.startOffset = videoStartUs;
params.endOffset = videoEndUs;
params.timeoutMs = decoderTimeoutMs;
params.preventStaleness = false;
if (readVideoStream == 1) {
MediaFormat videoFormat(0);
videoFormat.type = TYPE_VIDEO;
videoFormat.format.video.format = defaultVideoPixelFormat;
videoFormat.format.video.width = videoWidth;
videoFormat.format.video.height = videoHeight;
videoFormat.format.video.minDimension = videoMinDimension;
videoFormat.format.video.maxDimension = videoMaxDimension;
params.formats.insert(videoFormat);
}
if (readAudioStream == 1) {
MediaFormat audioFormat;
audioFormat.type = TYPE_AUDIO;
audioFormat.format.audio.format = defaultAudioSampleFormat;
audioFormat.format.audio.samples = audioSamples;
audioFormat.format.audio.channels = audioChannels;
params.formats.insert(audioFormat);
}
return params;
}
// returns number of written bytes
template <typename T>
size_t fillTensor(
std::vector<DecoderOutputMessage>& msgs,
torch::Tensor& frame,
torch::Tensor& framePts,
int64_t num,
int64_t den) {
if (msgs.empty()) {
return 0;
}
T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
int64_t* framePtsData = framePts.data_ptr<int64_t>();
CHECK_EQ(framePts.size(0), msgs.size());
size_t avgElementsInFrame = frame.numel() / msgs.size();
size_t offset = 0;
for (size_t i = 0; i < msgs.size(); ++i) {
const auto& msg = msgs[i];
// convert pts into original time_base
AVRational avr = AVRational{(int)num, (int)den};
framePtsData[i] = av_rescale_q(msg.header.pts, timeBaseQ, avr);
VLOG(2) << "PTS type: " << sizeof(T) << ", us: " << msg.header.pts
<< ", original: " << framePtsData[i];
if (frameData) {
auto sizeInBytes = msg.payload->length();
memcpy(frameData + offset, msg.payload->data(), sizeInBytes);
if (sizeof(T) == sizeof(uint8_t)) {
// Video - move by allocated frame size
offset += avgElementsInFrame / sizeof(T);
} else {
// Audio - move by number of samples
offset += sizeInBytes / sizeof(T);
}
}
}
return offset * sizeof(T);
}
size_t fillVideoTensor(
std::vector<DecoderOutputMessage>& msgs,
torch::Tensor& videoFrame,
torch::Tensor& videoFramePts,
int64_t num,
int64_t den) {
return fillTensor<uint8_t>(msgs, videoFrame, videoFramePts, num, den);
}
size_t fillAudioTensor(
std::vector<DecoderOutputMessage>& msgs,
torch::Tensor& audioFrame,
torch::Tensor& audioFramePts,
int64_t num,
int64_t den) {
return fillTensor<float>(msgs, audioFrame, audioFramePts, num, den);
}
void offsetsToUs(
double& seekFrameMargin,
int64_t readVideoStream,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen,
int64_t& videoStartUs,
int64_t& videoEndUs) {
seekFrameMargin *= AV_TIME_BASE;
videoStartUs = 0;
videoEndUs = -1;
if (readVideoStream) {
AVRational vr = AVRational{(int)videoTimeBaseNum, (int)videoTimeBaseDen};
if (videoStartPts > 0) {
videoStartUs = av_rescale_q(videoStartPts, vr, timeBaseQ);
}
if (videoEndPts > 0) {
// Add jitter to the end of the range to avoid conversion/rounding error.
// Small value 100us won't be enough to select the next frame, but enough
// to compensate rounding error due to the multiple conversions.
videoEndUs = timeBaseJitterUs + av_rescale_q(videoEndPts, vr, timeBaseQ);
}
} else if (readAudioStream) {
AVRational ar = AVRational{(int)audioTimeBaseNum, (int)audioTimeBaseDen};
if (audioStartPts > 0) {
videoStartUs = av_rescale_q(audioStartPts, ar, timeBaseQ);
}
if (audioEndPts > 0) {
// Add jitter to the end of the range to avoid conversion/rounding error.
// Small value 100us won't be enough to select the next frame, but enough
// to compensate rounding error due to the multiple conversions.
videoEndUs = timeBaseJitterUs + av_rescale_q(audioEndPts, ar, timeBaseQ);
}
}
}
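// Worked example: with a video time base of 1/30000 and videoStartPts = 90000,
// av_rescale_q(90000, {1, 30000}, {1, AV_TIME_BASE}) yields
// 90000 * 1000000 / 30000 = 3000000 us, i.e. a 3 s start offset. The 100 us
// jitter is only added to the end of the range.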
torch::List<torch::Tensor> readVideo(
bool isReadFile,
const torch::Tensor& input_video,
std::string videoPath,
double seekFrameMargin,
int64_t getPtsOnly,
int64_t readVideoStream,
int64_t width,
int64_t height,
int64_t minDimension,
int64_t maxDimension,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioSamples,
int64_t audioChannels,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
int64_t videoStartUs, videoEndUs;
offsetsToUs(
seekFrameMargin,
readVideoStream,
videoStartPts,
videoEndPts,
videoTimeBaseNum,
videoTimeBaseDen,
readAudioStream,
audioStartPts,
audioEndPts,
audioTimeBaseNum,
audioTimeBaseDen,
videoStartUs,
videoEndUs);
DecoderParameters params = getDecoderParams(
videoStartUs, // videoStartPts
videoEndUs, // videoEndPts
seekFrameMargin, // seekFrameMargin
getPtsOnly, // getPtsOnly
readVideoStream, // readVideoStream
width, // width
height, // height
minDimension, // minDimension
maxDimension, // maxDimension
readAudioStream, // readAudioStream
audioSamples, // audioSamples
audioChannels // audioChannels
);
SyncDecoder decoder;
std::vector<DecoderOutputMessage> audioMessages, videoMessages;
DecoderInCallback callback = nullptr;
std::string logMessage, logType;
if (isReadFile) {
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
} else {
callback = MemoryBuffer::getCallback(
input_video.data_ptr<uint8_t>(), input_video.size(0));
logType = "memory";
logMessage = std::to_string(input_video.size(0));
}
VLOG(1) << "Video decoding from " << logType << " [" << logMessage
<< "] has started";
const auto now = std::chrono::system_clock::now();
bool succeeded;
DecoderMetadata audioMetadata, videoMetadata;
std::vector<DecoderMetadata> metadata;
if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
for (const auto& header : metadata) {
if (header.format.type == TYPE_VIDEO) {
videoMetadata = header;
} else if (header.format.type == TYPE_AUDIO) {
audioMetadata = header;
}
}
int res;
DecoderOutputMessage msg;
while (0 == (res = decoder.decode(&msg, decoderTimeoutMs))) {
if (msg.header.format.type == TYPE_VIDEO) {
videoMessages.push_back(std::move(msg));
}
if (msg.header.format.type == TYPE_AUDIO) {
audioMessages.push_back(std::move(msg));
}
msg.payload.reset();
}
} else {
LOG(ERROR) << "Decoder initialization has failed";
}
const auto then = std::chrono::system_clock::now();
VLOG(1) << "Video decoding from " << logType << " [" << logMessage
<< "] has finished, "
<< std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count()
<< " us";
decoder.shutdown();
// video section
torch::Tensor videoFrame = torch::zeros({0}, torch::kByte);
torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
if (succeeded && readVideoStream == 1) {
if (!videoMessages.empty()) {
const auto& header = videoMetadata;
const auto& format = header.format.format.video;
int numVideoFrames = videoMessages.size();
int outHeight = format.height;
int outWidth = format.width;
int numChannels = 3; // decoder guarantees the default AV_PIX_FMT_RGB24
size_t expectedWrittenBytes = 0;
if (getPtsOnly == 0) {
videoFrame = torch::zeros(
{numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte);
expectedWrittenBytes =
(size_t)numVideoFrames * outHeight * outWidth * numChannels;
}
videoFramePts = torch::zeros({numVideoFrames}, torch::kLong);
VLOG(2) << "video duration: " << header.duration
<< ", fps: " << header.fps << ", num: " << header.num
<< ", den: " << header.den << ", num frames: " << numVideoFrames;
auto numberWrittenBytes = fillVideoTensor(
videoMessages, videoFrame, videoFramePts, header.num, header.den);
CHECK_EQ(numberWrittenBytes, expectedWrittenBytes);
videoTimeBase = torch::zeros({2}, torch::kInt);
int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
videoTimeBaseData[0] = header.num;
videoTimeBaseData[1] = header.den;
videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = header.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
AVRational vr = AVRational{(int)header.num, (int)header.den};
videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, vr);
VLOG(1) << "Video decoding from " << logType << " [" << logMessage
<< "] filled video tensors";
} else {
VLOG(1) << "Miss video stream";
}
}
// audio section
torch::Tensor audioFrame = torch::zeros({0}, torch::kFloat);
torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
if (succeeded && readAudioStream == 1) {
if (!audioMessages.empty()) {
const auto& header = audioMetadata;
const auto& format = header.format.format.audio;
int64_t outAudioChannels = format.channels;
int bytesPerSample =
av_get_bytes_per_sample(static_cast<AVSampleFormat>(format.format));
int numAudioFrames = audioMessages.size();
int64_t numAudioSamples = 0;
if (getPtsOnly == 0) {
int64_t frameSizeTotal = 0;
for (auto const& audioMessage : audioMessages) {
frameSizeTotal += audioMessage.payload->length();
}
CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
numAudioSamples = frameSizeTotal / (outAudioChannels * bytesPerSample);
audioFrame =
torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
}
audioFramePts = torch::zeros({numAudioFrames}, torch::kLong);
VLOG(2) << "audio duration: " << header.duration
<< ", channels: " << format.channels
<< ", sample rate: " << format.samples << ", num: " << header.num
<< ", den: " << header.den;
auto numberWrittenBytes = fillAudioTensor(
audioMessages, audioFrame, audioFramePts, header.num, header.den);
CHECK_EQ(
numberWrittenBytes,
numAudioSamples * outAudioChannels * sizeof(float));
audioTimeBase = torch::zeros({2}, torch::kInt);
int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
audioTimeBaseData[0] = header.num;
audioTimeBaseData[1] = header.den;
audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = format.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
AVRational ar = AVRational{(int)header.num, (int)header.den};
audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, ar);
VLOG(1) << "Video decoding from " << logType << " [" << logMessage
<< "] filled audio tensors";
} else {
VLOG(1) << "Miss audio stream";
}
}
torch::List<torch::Tensor> result;
result.push_back(std::move(videoFrame));
result.push_back(std::move(videoFramePts));
result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioFrame));
result.push_back(std::move(audioFramePts));
result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
VLOG(1) << "Video decoding from " << logType << " [" << logMessage
<< "] about to return";
return result;
}
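// Note on the return layout above: the list always contains ten tensors, in
// order: videoFrame (N x H x W x 3 uint8), videoFramePts, videoTimeBase
// (num, den), videoFps, videoDuration, audioFrame (samples x channels float),
// audioFramePts, audioTimeBase, audioSampleRate, audioDuration. Streams that
// were not requested or not decoded come back as empty tensors.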
torch::List<torch::Tensor> readVideoFromMemory(
torch::Tensor input_video,
double seekFrameMargin,
int64_t getPtsOnly,
int64_t readVideoStream,
int64_t width,
int64_t height,
int64_t minDimension,
int64_t maxDimension,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioSamples,
int64_t audioChannels,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
return readVideo(
false,
input_video,
"", // videoPath
seekFrameMargin,
getPtsOnly,
readVideoStream,
width,
height,
minDimension,
maxDimension,
videoStartPts,
videoEndPts,
videoTimeBaseNum,
videoTimeBaseDen,
readAudioStream,
audioSamples,
audioChannels,
audioStartPts,
audioEndPts,
audioTimeBaseNum,
audioTimeBaseDen);
}
torch::List<torch::Tensor> readVideoFromFile(
std::string videoPath,
double seekFrameMargin,
int64_t getPtsOnly,
int64_t readVideoStream,
int64_t width,
int64_t height,
int64_t minDimension,
int64_t maxDimension,
int64_t videoStartPts,
int64_t videoEndPts,
int64_t videoTimeBaseNum,
int64_t videoTimeBaseDen,
int64_t readAudioStream,
int64_t audioSamples,
int64_t audioChannels,
int64_t audioStartPts,
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
torch::Tensor dummy_input_video = torch::ones({0});
return readVideo(
true,
dummy_input_video,
videoPath,
seekFrameMargin,
getPtsOnly,
readVideoStream,
width,
height,
minDimension,
maxDimension,
videoStartPts,
videoEndPts,
videoTimeBaseNum,
videoTimeBaseDen,
readAudioStream,
audioSamples,
audioChannels,
audioStartPts,
audioEndPts,
audioTimeBaseNum,
audioTimeBaseDen);
}
torch::List<torch::Tensor> probeVideo(
bool isReadFile,
const torch::Tensor& input_video,
std::string videoPath) {
DecoderParameters params = getDecoderParams(
0, // videoStartUs
-1, // videoEndUs
0, // seekFrameMargin
1, // getPtsOnly
1, // readVideoStream
0, // width
0, // height
0, // minDimension
0, // maxDimension
1, // readAudioStream
0, // audioSamples
0 // audioChannels
);
SyncDecoder decoder;
DecoderInCallback callback = nullptr;
std::string logMessage, logType;
if (isReadFile) {
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
} else {
callback = MemoryBuffer::getCallback(
input_video.data_ptr<uint8_t>(), input_video.size(0));
logType = "memory";
logMessage = std::to_string(input_video.size(0));
}
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] has started";
const auto now = std::chrono::system_clock::now();
bool succeeded;
bool gotAudio = false, gotVideo = false;
DecoderMetadata audioMetadata, videoMetadata;
std::vector<DecoderMetadata> metadata;
if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
for (const auto& header : metadata) {
if (header.format.type == TYPE_VIDEO) {
gotVideo = true;
videoMetadata = header;
} else if (header.format.type == TYPE_AUDIO) {
gotAudio = true;
audioMetadata = header;
}
}
const auto then = std::chrono::system_clock::now();
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] has finished, "
<< std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count()
<< " us";
} else {
LOG(ERROR) << "Decoder initialization has failed";
}
decoder.shutdown();
// video section
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
if (succeeded && gotVideo) {
videoTimeBase = torch::zeros({2}, torch::kInt);
int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
const auto& header = videoMetadata;
const auto& media = header.format;
videoTimeBaseData[0] = header.num;
videoTimeBaseData[1] = header.den;
videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = header.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
AVRational avr = AVRational{(int)header.num, (int)header.den};
videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
VLOG(2) << "Prob fps: " << header.fps << ", duration: " << header.duration
<< ", num: " << header.num << ", den: " << header.den;
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] filled video tensors";
} else {
LOG(ERROR) << "Miss video stream";
}
// audio section
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
if (succeeded && gotAudio) {
audioTimeBase = torch::zeros({2}, torch::kInt);
int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
const auto& header = audioMetadata;
const auto& media = header.format;
const auto& format = media.format.audio;
audioTimeBaseData[0] = header.num;
audioTimeBaseData[1] = header.den;
audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = format.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
AVRational avr = AVRational{(int)header.num, (int)header.den};
audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
VLOG(2) << "Prob sample rate: " << format.samples
<< ", duration: " << header.duration << ", num: " << header.num
<< ", den: " << header.den;
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] filled audio tensors";
} else {
VLOG(1) << "Miss audio stream";
}
torch::List<torch::Tensor> result;
result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] is about to return";
return result;
}
torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
return probeVideo(false, input_video, "");
}
torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
torch::Tensor dummy_input_video = torch::ones({0});
return probeVideo(true, dummy_input_video, videoPath);
}
} // namespace video_reader
static auto registry = torch::RegisterOperators()
.op("video_reader::read_video_from_memory",
&video_reader::readVideoFromMemory)
.op("video_reader::read_video_from_file",
&video_reader::readVideoFromFile)
.op("video_reader::probe_video_from_memory",
&video_reader::probeVideoFromMemory)
.op("video_reader::probe_video_from_file",
&video_reader::probeVideoFromFile);
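// Note: the registration above exposes the functions as TorchScript operators
// (e.g. "video_reader::read_video_from_file"); from Python they are reachable
// as torch.ops.video_reader.* once the extension library has been loaded.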
#pragma once
#include <torch/extension.h>
#include "../macros.h"
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
VISION_API at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
VISION_API at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned);
VISION_API at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned);
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
VISION_API at::Tensor PSROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
VISION_API at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width);
VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include "cuda_helpers.h"
#include <cmath>
#include <iostream>
#include <tuple>
const unsigned int CUDA_NUM_THREADS = 1024;
const int kMaxParallelImgs = 32;
inline unsigned int GET_BLOCKS(const unsigned int N) {
unsigned int kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
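// Note on the device function below: bilinear_interpolate samples the
// height x width map `in` at the fractional location (h, w) by combining the
// four neighbouring pixels with weights
//   w1 = (1-lh)(1-lw), w2 = (1-lh)lw, w3 = lh(1-lw), w4 = lh*lw,
// where lh = h - floor(h) and lw = w - floor(w); locations outside the map
// contribute zero.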
template <typename scalar_t>
__device__ scalar_t bilinear_interpolate(
const scalar_t* in,
const int height,
const int width,
scalar_t h,
scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = in[h_low * width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = in[h_low * width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = in[h_high * width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = in[h_high * width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(
const int n,
const scalar_t* input_ptr,
const scalar_t* offset_ptr,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dil_h,
const int dil_w,
const int batch_sz,
const int n_in_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* columns_ptr) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int out_b = (index / (out_w * out_h)) % batch_sz;
const int in_c = index / (out_w * out_h * batch_sz);
const int out_c = in_c * weight_h * weight_w;
int c_per_offset_grp = n_in_channels / n_offset_grps;
const int grp_idx = in_c / c_per_offset_grp;
columns_ptr +=
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
out_y * out_w + out_x);
input_ptr +=
(out_b * (n_in_channels * height * width) + in_c * (height * width));
offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w *
out_h * out_w;
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
}
}
static void deformable_im2col(
const at::Tensor input,
const at::Tensor data_offset,
int n_in_channels,
int height,
int width,
int weight_h,
int weight_w,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dil_h,
int dil_w,
int out_h,
int out_w,
int parallel_imgs,
int deformable_group,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deformable_im2col_gpu", ([&] {
deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
parallel_imgs,
n_in_channels,
deformable_group,
out_h,
out_w,
data_col.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
static int get_greatest_divisor_below_bound(int n, int bound) {
for (int k = bound; k > 1; --k) {
if (n % k == 0) {
return k;
}
}
return 1;
}
at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
int in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
int out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
TORCH_CHECK(
weight_h > 0 && weight_w > 0,
"weight_h: ",
weight_h,
" weight_w: ",
weight_w);
TORCH_CHECK(
stride_h > 0 && stride_w > 0,
"stride_h: ",
stride_h,
" stride_w: ",
stride_w);
TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w);
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"offset.shape[1] is not valid: got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
offset.size(2),
", ",
offset.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
out_h,
" out_w: ",
out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
if (batch_sz == 0) {
return out;
}
// Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
out_channels,
out_h,
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs * out_h,
out_w},
out.options());
// Separate channels into convolution groups
out_buf = out_buf.view({out_buf.size(0),
n_weight_grps,
out_buf.size(1) / n_weight_grps,
out_buf.size(2),
out_buf.size(3)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
// Sample points and perform convolution
auto columns = at::zeros(
{in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
input[b],
offset[b],
in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
out_buf[b][g] = out_buf[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
columns = columns.view(
{columns.size(0) * columns.size(1), columns.size(2)});
}
out_buf = out_buf.view({batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs,
out_h,
out_w});
out_buf.transpose_(1, 2);
out.copy_(out_buf);
out = out.view({batch_sz, out_channels, out_h, out_w});
return out + bias.view({1, out_channels, 1, 1});
}
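// Scatters the column gradient back onto the input image. Each thread handles
// one column entry, recomputes its bilinear sampling location from the
// offsets, and atomically adds the bilinearly weighted contribution to the
// (at most four) neighbouring input pixels.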
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n,
const scalar_t* col,
const scalar_t* offset_ptr,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_im) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int b = (index / (out_w * out_h)) % batch_sz;
const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
int c_per_offset_grp = channels / n_offset_grps;
const int offset_grp = c / c_per_offset_grp;
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
out_h * out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
int yp = int(y) + dy;
int xp = int(x) + dx;
if (0 <= yp && yp < height && 0 <= xp && xp < width &&
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
atomicAdd(grad_im + grad_pos, weight * col[index]);
}
}
}
}
}
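// Host wrapper that launches deformable_col2im_gpu_kernel over every column
// entry to accumulate the gradient with respect to the input image.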
static void compute_grad_input(
const at::Tensor columns,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_gpu", ([&] {
deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
n_offset_grps,
out_h,
out_w,
grad_im.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in compute_grad_input: %s\n", cudaGetErrorString(err));
}
}
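// Derivative of the bilinearly interpolated value at (y, x) with respect to
// the y coordinate (is_y_direction == true) or the x coordinate, computed
// from the four neighbouring pixels; out-of-bounds neighbours contribute zero.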
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(
const scalar_t* im_data,
const int height,
const int width,
scalar_t y,
scalar_t x,
bool is_y_direction) {
int y_l = floor(y);
int x_l = floor(x);
int y_h = y_l + 1;
int x_h = x_l + 1;
bool valid_y_l = 0 <= y_l && y_l < height;
bool valid_y_h = 0 <= y_h && y_h < height;
bool valid_x_l = 0 <= x_l && x_l < width;
bool valid_x_h = 0 <= x_h && x_h < width;
scalar_t zero = 0;
scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;
if (is_y_direction) {
scalar_t dx = x - x_l;
return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
} else {
scalar_t dy = y - y_l;
return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
}
}
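// Gradient with respect to the offsets. Each thread handles one offset
// element and accumulates, over every input channel of its offset group, the
// column gradient weighted by the coordinate derivative returned by
// get_coordinate_weight.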
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(
const int n,
const scalar_t* col_ptr,
const scalar_t* im_ptr,
const scalar_t* offset_ptr,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int offset_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_offset) {
CUDA_1D_KERNEL_LOOP(index, n) {
scalar_t val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
const int offset_grp = c / (2 * weight_h * weight_w);
const int col_step = weight_h * weight_w;
int c_per_offset_grp = channels / n_offset_grps;
col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
out_w * out_h;
im_ptr +=
(b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w *
out_h * out_w;
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const bool is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
int out_x = col_pos % out_w;
int out_y = (col_pos / out_w) % out_h;
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int offset_h_ptr =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
const int offset_w_ptr =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
im_ptr += height * width;
}
grad_offset[index] = val;
}
}
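// Host wrapper that launches deformable_col2im_coord_gpu_kernel over every
// offset element to fill the gradient with respect to the offsets.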
static void compute_grad_offset(
const at::Tensor columns,
const at::Tensor input,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_offset) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
2 * weight_h * weight_w * n_offset_grps,
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err));
}
}
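// Backward pass with respect to input and offset: the grouped GEMM of the
// forward pass is replayed with the weight transposed to recover the column
// gradients, which are then turned into grad_offset (compute_grad_offset) and
// grad_input (compute_grad_input).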
static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
long n_in_channels = input.size(1);
long in_h = input.size(2);
long in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
}
auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
// Separate into blocks
grad_input = grad_input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}).permute({0, 2, 3, 1, 4, 5});
weight = weight.reshape({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
columns.zero_();
// Separate into weight groups
for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
compute_grad_offset(
columns,
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
compute_grad_input(
columns,
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_input[elt]);
}
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
}
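// Backward pass with respect to the weights: the column buffer is recomputed
// with deformable_im2col and grad_weight accumulates grad_out x columns^T for
// each weight group.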
static at::Tensor deform_conv_backward_parameters_cuda(
at::Tensor input,
const at::Tensor& weight,
at::Tensor offset,
const at::Tensor& grad_out,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
long n_in_channels = input.size(1);
long in_h = input.size(2);
long in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
long out_h = grad_out.size(2);
long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight);
if (batch_sz == 0) {
return grad_weight;
}
at::Tensor grad_out_buf = grad_out.reshape(
{batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}
).permute({0, 2, 3, 1, 4, 5}).contiguous();
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_weight = grad_weight.reshape({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
auto columns = at::empty(
{n_weight_grps,
n_in_channels * weight_w * weight_h / n_weight_grps,
n_parallel_imgs * out_h * out_w},
input.options());
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col(
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]);
}
}
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
grad_weight.size(4)});
return grad_weight;
}
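// Entry point of the backward pass: computes grad_input/grad_offset and
// grad_weight with the helpers above, and grad_bias as the sum of grad_out
// over the batch and spatial dimensions.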
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv_backward_input_cuda(
input,
weight,
offset,
grad_out,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_weight = deform_conv_backward_parameters_cuda(
input,
weight,
offset,
grad_out,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto value = grad_out.sum({0, 2, 3});
auto grad_bias = at::ones_like(bias) * value;
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
}