Unverified Commit 07d681ac authored by qjqfl's avatar qjqfl Committed by GitHub
Browse files

[Feature]: support tensorrt custom plugin `MMCVCornerPool` (#1179)

parent 2dc0a219
#include "trt_corner_pool.hpp"
#include <assert.h>
#include "trt_serialize.hpp"
void CornerPoolForwardLauncher_float(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream);
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *CORNER_POOL_PLUGIN_NAME{"MMCVCornerPool"};
} // namespace
CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string &name,
TRT_CORNER_POOL_TYPE poolType)
: mLayerName(name), mPoolType(poolType) {}
CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string name,
const void *data,
size_t length)
: mLayerName(name) {
deserialize_value(&data, &length, &mPoolType);
}
CornerPoolPluginDynamic::~CornerPoolPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt *CornerPoolPluginDynamic::clone() const {
CornerPoolPluginDynamic *plugin =
new CornerPoolPluginDynamic(mLayerName, mPoolType);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs CornerPoolPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
return inputs[0];
}
bool CornerPoolPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
switch (pos) {
// input[0]
case 0:
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
// output[0]
case 1:
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
default:
return false;
}
}
void CornerPoolPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
size_t CornerPoolPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
}
int CornerPoolPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
const void *input = inputs[0];
void *output_value = outputs[0];
const int batch_size = inputDesc[0].dims.d[0];
const int channels = inputDesc[0].dims.d[1];
const int height = inputDesc[0].dims.d[2];
const int width = inputDesc[0].dims.d[3];
CornerPoolForwardLauncher_float((float *)input, (float *)output_value,
batch_size, channels, height, width,
int(mPoolType), stream);
return 0;
}
nvinfer1::DataType CornerPoolPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
return inputTypes[0];
}
// IPluginV2 Methods
const char *CornerPoolPluginDynamic::getPluginType() const {
switch (mPoolType) {
case TRT_CORNER_POOL_TYPE::TRT_TOP_POOL:
case TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL:
case TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL:
case TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL:
return CORNER_POOL_PLUGIN_NAME;
default:
return "UnknownpoolType";
}
}
const char *CornerPoolPluginDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}
int CornerPoolPluginDynamic::getNbOutputs() const { return 1; }
int CornerPoolPluginDynamic::initialize() { return 0; }
void CornerPoolPluginDynamic::terminate() {}
size_t CornerPoolPluginDynamic::getSerializationSize() const {
return sizeof(mPoolType);
}
void CornerPoolPluginDynamic::serialize(void *buffer) const {
serialize_value(&buffer, mPoolType);
}
void CornerPoolPluginDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}
void CornerPoolPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
const char *CornerPoolPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}
CornerPoolPluginDynamicCreator::CornerPoolPluginDynamicCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *CornerPoolPluginDynamicCreator::getPluginName() const {
return CORNER_POOL_PLUGIN_NAME;
}
const char *CornerPoolPluginDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection *
CornerPoolPluginDynamicCreator::getFieldNames() {
return &mFC;
}
nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
TRT_CORNER_POOL_TYPE poolType;
int poolMode = -1;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("mode") == 0) {
poolMode = static_cast<const int *>(fc->fields[i].data)[0];
}
}
assert(poolMode >= 0 && poolMode <= 3);
switch (poolMode) {
case 0:
poolType = TRT_CORNER_POOL_TYPE::TRT_TOP_POOL;
break;
case 1:
poolType = TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL;
break;
case 2:
poolType = TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL;
break;
case 3:
poolType = TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL;
break;
default:
break;
}
CornerPoolPluginDynamic *plugin = new CornerPoolPluginDynamic(name, poolType);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
// This object will be deleted when the network is destroyed, which will
// call FCPluginDynamic::destroy()
auto plugin = new CornerPoolPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
void CornerPoolPluginDynamicCreator::setPluginNamespace(
const char *libNamespace) {
mNamespace = libNamespace;
}
const char *CornerPoolPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
template <typename scalar_t>
__global__ void top_bottom_pool_kernel(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type) {
const int nthreads = batch_size * channels * width;
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n_idx = index / (channels * width); // batch
int w_idx = index % width; // width
int c_idx = (index / width) % channels; // channels
int offset_n = n_idx * channels * width * height;
int offset_n_c = offset_n + c_idx * width * height;
int direction = -1; // in [-1, 1], default for TopPool
int index_start = height - 2; // default for TopPool
// pool_type in [0, 1]
if (pool_type == 0) {
// TopPool
// directly copy the most bottom value from input to output
output[offset_n_c + (height - 1) * width + w_idx] =
input[offset_n_c + (height - 1) * width + w_idx];
} else {
// BottomPool
// directly copy the most top value from input to output
output[offset_n_c + w_idx] = input[offset_n_c + w_idx];
index_start = 1;
direction = 1;
}
// do pool
for (int h = index_start; h >= 0 && h < height; h += direction) {
output[offset_n_c + h * width + w_idx] =
max(output[offset_n_c + (h - direction) * width + w_idx],
input[offset_n_c + h * width + w_idx]);
}
}
}
template <typename scalar_t>
__global__ void left_right_pool_kernel(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type) {
const int nthreads = batch_size * channels * height;
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n_idx = index / (channels * height); // batch
int h_idx = index % height; // height
int c_idx = (index / height) % channels; // channels
int offset_n = n_idx * channels * width * height;
int offset_n_c = offset_n + c_idx * width * height;
int offset_n_c_h = offset_n_c + h_idx * width;
int direction = -1; // in [-1, 1], default for LeftPool
int index_start = width - 2; // default for LeftPool
// pool_type in [2, 3]
if (pool_type == 2) {
// LeftPool
// directly copy the most right value from input to output
output[offset_n_c_h + width - 1] = input[offset_n_c_h + width - 1];
} else {
// RightPool
// directly copy the most left value from input to output
output[offset_n_c_h] = input[offset_n_c_h];
index_start = 1;
direction = 1;
}
// do pool
for (int w = index_start; w >= 0 && w < width; w += direction) {
output[offset_n_c_h + w] =
max(output[offset_n_c_h + w - direction], input[offset_n_c_h + w]);
}
}
}
template <typename scalar_t>
void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream) {
int nthreads = -1, col_block = -1;
switch (pool_type) {
case 0:
case 1:
nthreads = batch_size * channels * width;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
top_bottom_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);
break;
case 2:
case 3:
nthreads = batch_size * channels * height;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
left_right_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);
break;
}
}
void CornerPoolForwardLauncher_float(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream) {
CornerPoolForwardLauncher<float>(input, output, batch_size, channels, height,
width, pool_type, stream);
}
#include "trt_plugin.hpp" #include "trt_plugin.hpp"
#include "trt_corner_pool.hpp"
#include "trt_cummaxmin.hpp" #include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp" #include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp" #include "trt_grid_sampler.hpp"
...@@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); ...@@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator); REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator);
REGISTER_TENSORRT_PLUGIN(CornerPoolPluginDynamicCreator);
extern "C" { extern "C" {
bool initLibMMCVInferPlugins() { return true; } bool initLibMMCVInferPlugins() { return true; }
......
#ifndef TRT_CORNER_POOL_HPP
#define TRT_CORNER_POOL_HPP
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
enum TRT_CORNER_POOL_TYPE {
TRT_TOP_POOL = 0,
TRT_BOTTOM_POOL = 1,
TRT_LEFT_POOL = 2,
TRT_RIGHT_POOL = 3
};
// implement of CornerPool
class CornerPoolPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
CornerPoolPluginDynamic(const std::string &name,
TRT_CORNER_POOL_TYPE poolType);
CornerPoolPluginDynamic(const std::string name, const void *data,
size_t length);
CornerPoolPluginDynamic() = delete;
~CornerPoolPluginDynamic();
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const override;
// IPluginV2 Methods
const char *getPluginType() const override;
const char *getPluginVersion() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void *buffer) const override;
void destroy() override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
protected:
const std::string mLayerName;
std::string mNamespace;
TRT_CORNER_POOL_TYPE mPoolType;
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
// CornerPool creator
class CornerPoolPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
CornerPoolPluginDynamicCreator();
const char *getPluginName() const override;
const char *getPluginVersion() const override;
const nvinfer1::PluginFieldCollection *getFieldNames() override;
nvinfer1::IPluginV2 *createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serialData,
size_t serialLength) override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
protected:
nvinfer1::PluginFieldCollection mFC;
std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
#endif TRT_CORNER_POOL_HPP // TRT_CORNER_POOL_HPP
...@@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode): ...@@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode):
if os.path.exists(trt_file): if os.path.exists(trt_file):
os.remove(trt_file) os.remove(trt_file)
assert torch.allclose(pytorch_results, trt_results) assert torch.allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode):
try:
from mmcv.ops import CornerPool
except (ImportError, ModuleNotFoundError):
pytest.skip('test requires compilation')
opset = 11
# register custom op `mmcv::MMCVCornerPool`
from mmcv.onnx.symbolic import register_extra_symbolics
register_extra_symbolics(opset)
# trt config
fp16_mode = False
max_workspace_size = 1 << 30
inputs = [
# (n, c, h, w)
torch.rand((2, 3, 5, 5)),
torch.rand((1, 2, 4, 6)),
torch.rand((2, 1, 3, 2)),
]
class CornerPoolWrapper(CornerPool):
def __init__(self, mode):
super(CornerPoolWrapper, self).__init__(mode)
def forward(self, x):
# no use `torch.cummax`, instead `corner_pool` is used
# for various torch version
return self.corner_pool.apply(x)
wrapped_model = CornerPoolWrapper(mode).cuda()
for input in inputs:
input = input.cuda()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, ),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input'],
output_names=['output'],
opset_version=opset)
onnx_model = onnx.load(onnx_file)
# create trt engine and wraper
opt_shape_dict = {
'input': [list(input.shape),
list(input.shape),
list(input.shape)],
}
trt_engine = onnx2trt(
onnx_model,
opt_shape_dict,
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
trt_model = TRTWrapper(trt_file, ['input'], ['output'])
with torch.no_grad():
trt_outputs = trt_model({'input': input})
trt_pool_feat = trt_outputs['output']
# compute pytorch_output
with torch.no_grad():
pytorch_pool_feat = wrapped_model(input)
# allclose
if os.path.exists(onnx_file):
os.remove(onnx_file)
if os.path.exists(trt_file):
os.remove(trt_file)
assert torch.allclose(pytorch_pool_feat, trt_pool_feat, atol=1e-5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment