Unverified commit 9ba1f760 authored by q.yao, committed by GitHub

[Feature] : Add Deformable Conv2d TensorRT Plugin (#858)

* add dcn tensorrt plugin

* prepare for fp16 support

* fix for lint

* limit column buffer

* add docstring to memcpyPermute
parent 57f3a614
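For reference, here is a minimal sketch (not part of this commit) of how the registered MMCVDeformConv2d plugin could be added to a network through the TensorRT C++ API. The helper name addDeformConv and the network/tensor variables are hypothetical; the field names mirror the attributes registered by DeformableConvPluginDynamicCreator below.
#include <NvInfer.h>
#include <vector>
// Build-time helper: look up the "MMCVDeformConv2d" creator in the plugin
// registry, fill a PluginFieldCollection and add the plugin layer with the
// three expected inputs (input feature map, offset, weight).
nvinfer1::IPluginV2Layer *addDeformConv(nvinfer1::INetworkDefinition *network,
                                        nvinfer1::ITensor *input,
                                        nvinfer1::ITensor *offset,
                                        nvinfer1::ITensor *weight) {
  int stride[2] = {1, 1};
  int padding[2] = {0, 0};
  int dilation[2] = {1, 1};
  int groups = 1, deform_groups = 1, im2col_step = 32;
  std::vector<nvinfer1::PluginField> fields{
      {"stride", stride, nvinfer1::PluginFieldType::kINT32, 2},
      {"padding", padding, nvinfer1::PluginFieldType::kINT32, 2},
      {"dilation", dilation, nvinfer1::PluginFieldType::kINT32, 2},
      {"groups", &groups, nvinfer1::PluginFieldType::kINT32, 1},
      {"deform_groups", &deform_groups, nvinfer1::PluginFieldType::kINT32, 1},
      {"im2col_step", &im2col_step, nvinfer1::PluginFieldType::kINT32, 1}};
  nvinfer1::PluginFieldCollection fc{static_cast<int>(fields.size()),
                                     fields.data()};
  auto *creator =
      getPluginRegistry()->getPluginCreator("MMCVDeformConv2d", "1");
  if (creator == nullptr) return nullptr;  // plugin library not linked in
  nvinfer1::IPluginV2 *plugin = creator->createPlugin("deform_conv", &fc);
  nvinfer1::ITensor *inputs[3] = {input, offset, weight};
  return network->addPluginV2(inputs, 3, *plugin);
}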
@@ -66,11 +66,16 @@
#ifndef DEFORM_CONV_CUDA_KERNEL_CUH
#define DEFORM_CONV_CUDA_KERNEL_CUH
#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else  // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif  // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT
template <typename T>
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
......
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
using mmcv::TensorDesc;
template <class scalar_t>
__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n,
TensorDesc ts_src_stride,
TensorDesc ts_dst_stride,
TensorDesc ts_permute) {
const int src_dim = ts_src_stride.dim;
int *src_stride = &(ts_src_stride.stride[0]);
int *dst_stride = &(ts_dst_stride.stride[0]);
int *permute = &(ts_permute.shape[0]);
CUDA_1D_KERNEL_LOOP(index, n) {
size_t dst_index = index;
size_t src_index = 0;
for (int i = 0; i < src_dim; ++i) {
int dim_index = dst_index / dst_stride[i];
dst_index = dst_index % dst_stride[i];
src_index += dim_index * src_stride[permute[i]];
}
dst[index] = src[src_index];
}
}
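A worked example of the index arithmetic in the kernel above (illustration only):
// src shape (2, 3, 4) with permute {1, 0, 2} gives dst shape (3, 2, 4),
// src_stride = {12, 4, 1} and dst_stride = {8, 4, 1}.
// For dst index 9: 9 / 8 = 1 (rem 1), 1 / 4 = 0 (rem 1), 1 / 1 = 1, so the
// dst coordinate is (1, 0, 1) and
//   src_index = 1 * src_stride[1] + 0 * src_stride[0] + 1 * src_stride[2]
//             = 4 + 0 + 1 = 5,
// i.e. the src coordinate (0, 1, 1): exactly dst[i][j][k] = src[j][i][k].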
template <class scalar_t>
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
int *permute, int src_dim, cudaStream_t stream) {
size_t copy_size = 1;
TensorDesc ts_permute;
memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));
TensorDesc ts_src_stride;
TensorDesc ts_dst_stride;
ts_src_stride.dim = src_dim;
ts_dst_stride.dim = src_dim;
int *src_stride = &(ts_src_stride.stride[0]);
int *dst_stride = &(ts_dst_stride.stride[0]);
int *dst_size = &(ts_dst_stride.shape[0]);
src_stride[src_dim - 1] = 1;
dst_stride[src_dim - 1] = 1;
for (int i = src_dim - 1; i >= 0; --i) {
dst_size[i] = src_size[permute[i]];
if (i < src_dim - 1) {
src_stride[i] = src_stride[i + 1] * src_size[i + 1];
}
}
for (int i = src_dim - 1; i >= 0; --i) {
copy_size *= dst_size[i];
if (i < src_dim - 1) {
dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
}
}
copy_permute_kernel<scalar_t>
<<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute);
}
template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
int *permute, int src_dim,
cudaStream_t stream);
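A minimal host-side usage sketch of memcpyPermute, assuming src_dev and dst_dev are device buffers of the right size; the function and variable names are illustrative only.
#include <cuda_runtime.h>
// Swap the first two dimensions of a 2x3x4 float tensor on the device,
// writing the 3x2x4 result into dst_dev (dst[i][j][k] = src[j][i][k]).
void permute_example(const float *src_dev, float *dst_dev,
                     cudaStream_t stream) {
  int src_size[3] = {2, 3, 4};  // shape of the source tensor
  int permute[3] = {1, 0, 2};   // desired ordering of dimensions
  memcpyPermute<float>(dst_dev, src_dev, src_size, permute, 3, stream);
  cudaStreamSynchronize(stream);
}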
#include "trt_deform_conv.hpp"
#include <assert.h>
#include <chrono>
#include "trt_serialize.hpp"
void DeformConvForwardCUDAKernelLauncher_float(
const float *input, const float *weight, const float *offset, float *output,
void *workspace, int batchSize, int nInputPlane, int inputHeight,
int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
int padH, int dilationW, int dilationH, int group, int deformable_group,
int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream);
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVDeformConv2d"};
} // namespace
nvinfer1::PluginFieldCollection DeformableConvPluginDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
DeformableConvPluginDynamicCreator::mPluginAttributes;
DeformableConvPluginDynamic::DeformableConvPluginDynamic(
const std::string &name, const nvinfer1::Dims &stride,
const nvinfer1::Dims &padding, const nvinfer1::Dims &dilation,
const int deformableGroup, const int group, int im2colStep)
: mLayerName(name),
mStride(stride),
mPadding(padding),
mDilation(dilation),
mDeformableGroup(deformableGroup),
mGroup(group),
mIm2colStep(im2colStep) {
cublasCreate(&m_cublas_handle);
}
DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
const void *data,
size_t length)
: mLayerName(name) {
deserialize_value(&data, &length, &mStride);
deserialize_value(&data, &length, &mPadding);
deserialize_value(&data, &length, &mDilation);
deserialize_value(&data, &length, &mDeformableGroup);
deserialize_value(&data, &length, &mGroup);
deserialize_value(&data, &length, &mIm2colStep);
cublasCreate(&m_cublas_handle);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
// destroy cublas handle
cublasDestroy(m_cublas_handle);
}
nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const {
DeformableConvPluginDynamic *plugin =
new DeformableConvPluginDynamic(mLayerName, mStride, mPadding, mDilation,
mDeformableGroup, mGroup, mIm2colStep);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[2].d[0];
ret.d[2] = inputs[1].d[2];
ret.d[3] = inputs[1].d[3];
return ret;
}
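A short reading of the expression above (explanation only, based on the shapes this plugin expects):
// inputs[0] is the input feature map (N, C_in, H, W), inputs[1] is the offset
// (N, deform_group * 2 * kH * kW, outH, outW) and inputs[2] is the weight
// (C_out, C_in / group, kH, kW), so the output shape (N, C_out, outH, outW)
// can be read directly from the batch dim of the input, the first dim of the
// weight and the spatial dims of the offset. In the unit test below the
// offset is 1x8x2x2, giving a 1x1x2x2 output for the 1x1x3x3 input.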
bool DeformableConvPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
if (pos == 0) {
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
}
}
void DeformableConvPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
size_t DeformableConvPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
int batch_size = inputs[0].dims.d[0];
int nInputPlane = inputs[0].dims.d[1];
int inputHeight = inputs[0].dims.d[2];
int inputWidth = inputs[0].dims.d[3];
int nOutputPlane = outputs[0].dims.d[1];
int outputHeight = outputs[0].dims.d[2];
int outputWidth = outputs[0].dims.d[3];
int kW = inputs[2].dims.d[2];
int kH = inputs[2].dims.d[3];
int im2col_step = std::min(batch_size, mIm2colStep);
size_t col_size =
mmcv::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight *
outputWidth * sizeof_dtype);
size_t out_size = 0;
if (im2col_step != 1)
out_size = mmcv::getAlignedSize(batch_size * nOutputPlane * outputHeight *
outputWidth * sizeof_dtype);
return col_size + out_size;
}
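A worked example of the workspace arithmetic, assuming mmcv::getAlignedSize pads to a 16-byte multiple (an assumption about the helper, not shown in this diff):
// For the 1x1x3x3 float input of the unit test, kernel 2x2, stride 1, pad 0:
//   im2col_step = min(1, 32) = 1, output = 1x1x2x2
//   col_size = getAlignedSize(1 * 2 * 2 * 1 * 2 * 2 * 4) = 64 bytes
//   out_size = 0  (the extra buffer is only needed when the per-step output
//                  has to be permuted back, i.e. when im2col_step != 1)
// so the plugin requests 64 bytes of workspace.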
int DeformableConvPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
if (m_cuda_stream != stream) {
cublasSetStream(m_cublas_handle, stream);
m_cuda_stream = stream;
}
int batch_size = inputDesc[0].dims.d[0];
int inputChannel = inputDesc[0].dims.d[1];
int inputHeight = inputDesc[0].dims.d[2];
int inputWidth = inputDesc[0].dims.d[3];
int outputChannel = outputDesc[0].dims.d[1];
int kernelHeight = inputDesc[2].dims.d[2];
int kernelWidth = inputDesc[2].dims.d[3];
const void *x = inputs[0];
const void *offset = inputs[1];
const void *weight = inputs[2];
void *output = outputs[0];
int im2col_step = std::min(batch_size, mIm2colStep);
// TODO: add fp16 support
auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
DeformConvForwardCUDAKernelLauncher_float(
(float *)x, (float *)weight, (float *)offset, (float *)output,
workSpace, batch_size, inputChannel, inputHeight, inputWidth,
outputChannel, kernelWidth, kernelHeight, mStride.d[0], mStride.d[1],
mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup,
mDeformableGroup, im2col_step, m_cublas_handle, stream);
break;
default:
return 1;
break;
}
return 0;
}
nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
return inputTypes[0];
}
// IPluginV2 Methods
const char *DeformableConvPluginDynamic::getPluginType() const {
return PLUGIN_NAME;
}
const char *DeformableConvPluginDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}
int DeformableConvPluginDynamic::getNbOutputs() const { return 1; }
int DeformableConvPluginDynamic::initialize() { return 0; }
void DeformableConvPluginDynamic::terminate() {}
size_t DeformableConvPluginDynamic::getSerializationSize() const {
return sizeof(mStride) + sizeof(mPadding) + sizeof(mDilation) +
sizeof(mDeformableGroup) + sizeof(mGroup) + sizeof(mIm2colStep);
}
void DeformableConvPluginDynamic::serialize(void *buffer) const {
serialize_value(&buffer, mStride);
serialize_value(&buffer, mPadding);
serialize_value(&buffer, mDilation);
serialize_value(&buffer, mDeformableGroup);
serialize_value(&buffer, mGroup);
serialize_value(&buffer, mIm2colStep);
}
void DeformableConvPluginDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}
void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
const char *DeformableConvPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}
////////////////////// creator /////////////////////////////
DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() {
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("bias"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("im2col_step"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *DeformableConvPluginDynamicCreator::getPluginName() const {
return PLUGIN_NAME;
}
const char *DeformableConvPluginDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection *
DeformableConvPluginDynamicCreator::getFieldNames() {
return &mFC;
}
nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
nvinfer1::Dims stride{2, {1, 1}};
nvinfer1::Dims padding{2, {0, 0}};
nvinfer1::Dims dilation{2, {1, 1}};
int deformableGroup = 1;
int group = 1;
int im2col_step = 32;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("stride") == 0) {
stride.nbDims = 2;
stride.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
if (fc->fields[i].length == 1) {
stride.d[1] = stride.d[0];
} else {
stride.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
}
if (field_name.compare("padding") == 0) {
padding.nbDims = 2;
padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
if (fc->fields[i].length == 1) {
padding.d[1] = padding.d[0];
} else {
padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
}
if (field_name.compare("dilation") == 0) {
dilation.nbDims = 2;
dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
if (fc->fields[i].length == 1) {
dilation.d[1] = dilation.d[0];
} else {
dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
}
if (field_name.compare("deformable_group") == 0) {
deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("group") == 0) {
group = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("im2col_step") == 0) {
im2col_step = static_cast<const int *>(fc->fields[i].data)[0];
}
}
DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic(
name, stride, padding, dilation, deformableGroup, group, im2col_step);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
void DeformableConvPluginDynamicCreator::setPluginNamespace(
const char *libNamespace) {
mNamespace = libNamespace;
}
const char *DeformableConvPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "common_cuda_helper.hpp"
#include "deform_conv_cuda_kernel.cuh"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
template <typename T>
void trt_deformable_im2col(const T* data_input, const T* data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
T* data_col, cudaStream_t stream) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
deformable_im2col_gpu_kernel<T>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
num_kernels, data_input, data_offset, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels,
deformable_group, height_col, width_col, data_col);
cudaCheckError();
}
// used to switch gemm between fp32 and fp16
template <typename scalar_t>
cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const scalar_t* alpha, const scalar_t* A, int lda,
const scalar_t* B, int ldb, const scalar_t* beta,
scalar_t* C, int ldc) {
return CUBLAS_STATUS_INTERNAL_ERROR;
}
template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const float* alpha, const float* A,
int lda, const float* B, int ldb,
const float* beta, float* C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const half* alpha, const half* A,
int lda, const half* B, int ldb,
const half* beta, half* C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}
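A small sketch of how the wrapper keeps the launcher below type-generic; the handle and device buffers are assumed valid, and gemm_example is illustrative only.
// Column-major GEMM C = A * B with A (m x k), B (k x n), C (m x n); resolves
// to cublasSgemm for float and cublasHgemm for half, while the generic
// template above reports CUBLAS_STATUS_INTERNAL_ERROR for other types.
template <typename scalar_t>
cublasStatus_t gemm_example(cublasHandle_t handle, const scalar_t* A,
                            const scalar_t* B, scalar_t* C, int m, int n,
                            int k) {
  const scalar_t alpha = scalar_t(1.f);
  const scalar_t beta = scalar_t(0.f);
  return cublasGemmWrap<scalar_t>(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                  &alpha, A, m, B, k, &beta, C, m);
}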
template <typename scalar_t>
void DeformConvForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
scalar_t* output, void* workspace, int batchSize, int nInputPlane,
int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
cudaStream_t stream) {
size_t word_size = sizeof(scalar_t);
im2col_step = std::min(int(batchSize), im2col_step);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long long columns_size =
mmcv::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight *
outputWidth * word_size);
// column buffer for im2col
scalar_t* columns = (scalar_t*)workspace;
workspace = static_cast<char*>(workspace) + columns_size;
scalar_t* output_buffer;
long long output_buffer_size = 0;
if (im2col_step == 1) {
output_buffer = output;
} else {
// the output needs a permute when im2col_step != 1
output_buffer = (scalar_t*)workspace;
output_buffer_size = batchSize * nOutputPlane * outputWidth * outputHeight;
}
long long input_elt_step =
im2col_step * nInputPlane * inputHeight * inputWidth;
long long offset_elt_step =
im2col_step * deformable_group * 2 * kH * kW * outputHeight * outputWidth;
long long out_buffer_step =
nOutputPlane * im2col_step * outputHeight * outputWidth;
long long col_g_step =
nInputPlane * kW * kH / group * im2col_step * outputHeight * outputWidth;
long long weight_g_step =
nOutputPlane / group * nInputPlane / group * kH * kW;
long long out_buffer_g_step =
nOutputPlane / group * im2col_step * outputHeight * outputWidth;
int m = nOutputPlane / group;
int n = im2col_step * outputHeight * outputWidth;
int k = nInputPlane / group * kH * kW;
scalar_t alpha = 1.;
scalar_t beta = 0.;
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
const scalar_t* input_start = input + elt * input_elt_step;
const scalar_t* offset_start = offset + elt * offset_elt_step;
trt_deformable_im2col<scalar_t>(input_start, offset_start, nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW,
dH, dW, dilationH, dilationW, im2col_step,
deformable_group, columns, stream);
for (int g = 0; g < group; ++g) {
const scalar_t* weight_start = weight + g * weight_g_step;
scalar_t* col_start = columns + g * col_g_step;
scalar_t* out_buffer_start =
output_buffer + elt * out_buffer_step + g * out_buffer_g_step;
cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
&alpha, col_start, n, weight_start, k, &beta,
out_buffer_start, n);
cudaCheckError();
}
}
if (im2col_step != 1) {
int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth};
int output_buffer_permute[5] = {0, 2, 1, 3, 4};
memcpyPermute<scalar_t>(output, output_buffer, &output_buffer_shape[0],
&output_buffer_permute[0], 5, stream);
}
}
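A note on the GEMM call inside the loop above (explanation only):
// cuBLAS expects column-major operands, so passing the column buffer first
// with the (n, m, k) ordering computes, in row-major terms,
//   out_buffer[m x n] = weight[m x k] * columns[k x n]
// with m = nOutputPlane / group, k = nInputPlane / group * kH * kW and
// n = im2col_step * outputHeight * outputWidth, i.e. one grouped convolution
// step expressed as a matrix product over the im2col buffer.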
void DeformConvForwardCUDAKernelLauncher_float(
const float* input, const float* weight, const float* offset, float* output,
void* workspace, int batchSize, int nInputPlane, int inputHeight,
int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
int padH, int dilationW, int dilationH, int group, int deformable_group,
int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
DeformConvForwardCUDAKernelLauncher<float>(
input, weight, offset, output, workspace, batchSize, nInputPlane,
inputHeight, inputWidth, nOutputPlane, kW, kH, dW, dH, padW, padH,
dilationW, dilationH, group, deformable_group, im2col_step, cublas_handle,
stream);
}
#include "trt_plugin.hpp" #include "trt_plugin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_nms.hpp" #include "trt_nms.hpp"
#include "trt_roi_align.hpp" #include "trt_roi_align.hpp"
#include "trt_scatternd.hpp" #include "trt_scatternd.hpp"
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
......
@@ -13,4 +13,18 @@
} \
}
/**
 * Copy the source tensor into the destination tensor with its dimensions
 * permuted (a permuted copy, not a view).
 *
 * @param[out] dst pointer to the destination tensor
 * @param[in] src pointer to the source tensor
 * @param[in] src_size shape of the source tensor
 * @param[in] permute the desired ordering of dimensions
 * @param[in] src_dim number of dimensions of the source tensor
 * @param[in] stream cuda stream handle
 */
template <class scalar_t>
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
int *permute, int src_dim, cudaStream_t stream = 0);
#endif  // TRT_CUDA_HELPER_HPP
#ifndef TRT_DEFORM_CONV_HPP
#define TRT_DEFORM_CONV_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
DeformableConvPluginDynamic(const std::string &name,
const nvinfer1::Dims &stride,
const nvinfer1::Dims &padding,
const nvinfer1::Dims &dilation,
const int deformableGroup, const int group,
int im2colStep);
DeformableConvPluginDynamic(const std::string name, const void *data,
size_t length);
DeformableConvPluginDynamic() = delete;
~DeformableConvPluginDynamic();
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const override;
// IPluginV2 Methods
const char *getPluginType() const override;
const char *getPluginVersion() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void *buffer) const override;
void destroy() override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
private:
const std::string mLayerName;
std::string mNamespace;
nvinfer1::Dims mStride;
nvinfer1::Dims mPadding;
nvinfer1::Dims mDilation;
int mDeformableGroup;
int mGroup;
int mIm2colStep;
cublasHandle_t m_cublas_handle;
cudaStream_t m_cuda_stream = nullptr;
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
class DeformableConvPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
DeformableConvPluginDynamicCreator();
const char *getPluginName() const override;
const char *getPluginVersion() const override;
const nvinfer1::PluginFieldCollection *getFieldNames() override;
nvinfer1::IPluginV2 *createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serialData,
size_t serialLength) override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
#endif // TRT_DEFORM_CONV_HPP
@@ -32,7 +32,7 @@ class DeformConv2dFunction(Function):
                 bias=False,
                 im2col_step=32):
        return g.op(
            'mmcv::MMCVDeformConv2d',
            input,
            offset,
            weight,
......
@@ -326,3 +326,79 @@ def test_scatternd():
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)
def test_deform_conv():
    try:
        from mmcv.ops import DeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')
    input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
    offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
                     [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
                     [[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
                     [[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
    offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
    deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]
    c_in = 1
    c_out = 1
    x = torch.Tensor(input).cuda()
    x.requires_grad = True
    model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
    model.conv_offset.weight.data = torch.nn.Parameter(
        torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
    model.conv_offset.bias.data = torch.nn.Parameter(
        torch.Tensor(offset_bias).reshape(8))
    model.weight.data = torch.nn.Parameter(
        torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
    model.cuda().eval()
    input_names = ['input']
    output_names = ['output']
    with torch.no_grad():
        torch.onnx.export(
            model, (x.clone(), ),
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=input_names,
            output_names=output_names,
            opset_version=11)
    onnx_model = onnx.load(onnx_file)
    # create trt engine and wrapper
    opt_shape_dict = {
        'input': [list(x.shape), list(x.shape), list(x.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWraper(trt_file, input_names, output_names)
    with torch.no_grad():
        trt_outputs = trt_model({'input': x.clone()})
        trt_results = trt_outputs['output']
    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = model(x.clone())
    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)