Unverified Commit 1ee5315e authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Feature] : Add NonMaxSuppression TensorRT Plugin (#787)

* start trt plugin prototype

* Add test module, modify roialign converter

* finish roi_align trt plugin

* fix conflict of RoiAlign and MMCVRoiAlign

* fix for lint

* fix test tensorrt module

* test_tensorrt move import to test func

* add except error type

* add tensorrt to setup.cfg

* code format with yapf

* fix for clang-format

* move tensorrt_utils to mmcv/tensorrt, add comments, better test module

* fix line endings, docformatter

* isort init, remove trailing whitespace

* add except type

* fix setup.py

* put import extension inside trt setup

* change c++ guard, update pytest script, better setup, etc

* sort import with isort

* sort import with isort

* move init of plugin lib to init_plugins.py

* add scatternd, nms plugin (WIP)

* fix bugs of trt_nms

* add trt nms test module

* fix bugs of scatternd

* code optimize, add comment about nms kernel

* fix transform_if bug of t...
parent 8735815a
#ifndef NMS_CUDA_KERNEL_CUH #ifndef NMS_CUDA_KERNEL_CUH
#define NMS_CUDA_KERNEL_CUH #define NMS_CUDA_KERNEL_CUH
#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS #ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp" #include "parrots_cuda_helper.hpp"
#else #else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp" #include "pytorch_cuda_helper.hpp"
#endif #endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long int) * 8; int const threadsPerBlock = sizeof(unsigned long long int) * 8;
......
#include "trt_nms.hpp"
#include <assert.h>
#include <stdio.h>
#include <chrono>
#include "trt_serialize.hpp"
extern size_t get_onnxnms_workspace_size(
size_t num_batches, size_t spatial_dimension, size_t num_classes,
size_t boxes_word_size, int center_point_box, size_t output_length);
extern void TRTNMSCUDAKernelLauncher_float(
const float *boxes, const float *scores,
const int max_output_boxes_per_class, const float iou_threshold,
const float score_threshold, const int offset, int *output,
int center_point_box, int num_batches, int spatial_dimension,
int num_classes, size_t output_length, void *workspace,
cudaStream_t stream);
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"NonMaxSuppression"};
} // namespace
nvinfer1::PluginFieldCollection NonMaxSuppressionDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
NonMaxSuppressionDynamicCreator::mPluginAttributes;
// Build-time constructor: store all NMS attributes; they are used later by
// enqueue()/serialize() and copied by clone().
NonMaxSuppressionDynamic::NonMaxSuppressionDynamic(
    const std::string &name, int centerPointBox, int maxOutputBoxesPerClass,
    float iouThreshold, float scoreThreshold, int offset)
    : mLayerName(name),
      mCenterPointBox(centerPointBox),
      mMaxOutputBoxesPerClass(maxOutputBoxesPerClass),
      mIouThreshold(iouThreshold),
      mScoreThreshold(scoreThreshold),
      mOffset(offset) {}
// Deserialization constructor: rebuild the plugin from the byte stream
// produced by serialize(). The read order must match serialize() exactly.
NonMaxSuppressionDynamic::NonMaxSuppressionDynamic(const std::string name,
                                                   const void *data,
                                                   size_t length)
    : mLayerName(name) {
  deserialize_value(&data, &length, &mCenterPointBox);
  deserialize_value(&data, &length, &mMaxOutputBoxesPerClass);
  deserialize_value(&data, &length, &mIouThreshold);
  deserialize_value(&data, &length, &mScoreThreshold);
  deserialize_value(&data, &length, &mOffset);
}
// Create an independent copy of this plugin with identical attributes and
// namespace (TensorRT clones plugins while building/serializing engines).
nvinfer1::IPluginV2DynamicExt *NonMaxSuppressionDynamic::clone() const {
  auto *copy = new NonMaxSuppressionDynamic(
      mLayerName, mCenterPointBox, mMaxOutputBoxesPerClass, mIouThreshold,
      mScoreThreshold, mOffset);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}
// Compute the shape of the single output (selected_indices).
// Inputs (per the launcher's contract in the .cu file):
//   boxes  [num_batches, spatial_dimension, 4]
//   scores [num_batches, num_classes, spatial_dimension]
// Output shape is
//   [num_batches * min(spatial_dimension, mMaxOutputBoxesPerClass)
//      * num_classes, 3]
// where each row holds (batch_id, class_id, box_id).
nvinfer1::DimsExprs NonMaxSuppressionDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  nvinfer1::DimsExprs ret;
  ret.nbDims = 2;
  auto num_batches = inputs[0].d[0];
  auto spatial_dimension = inputs[0].d[1];
  // Cap the per-class candidate count when a positive limit was given;
  // a non-positive limit means "no limit" (see the launcher).
  if (mMaxOutputBoxesPerClass > 0) {
    spatial_dimension = exprBuilder.operation(
        nvinfer1::DimensionOperation::kMIN, *spatial_dimension,
        *exprBuilder.constant(mMaxOutputBoxesPerClass));
  }
  auto num_classes = inputs[1].d[1];
  // rows = num_batches * spatial_dimension * num_classes
  ret.d[0] = exprBuilder.operation(
      nvinfer1::DimensionOperation::kPROD, *num_batches,
      *exprBuilder.operation(nvinfer1::DimensionOperation::kPROD,
                             *spatial_dimension, *num_classes));
  ret.d[1] = exprBuilder.constant(3);
  return ret;
}
// Accept only linear-format tensors: float32 for the two inputs
// (boxes, scores) and int32 for the single output (selected_indices).
bool NonMaxSuppressionDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
    int nbOutputs) {
  const nvinfer1::PluginTensorDesc &desc = inOut[pos];
  const bool is_linear = desc.format == nvinfer1::TensorFormat::kLINEAR;
  if (pos < nbInputs) {
    // Input 0 is boxes, input 1 is scores; both must be float32 + linear.
    if (pos == 0 || pos == 1) {
      return desc.type == nvinfer1::DataType::kFLOAT && is_linear;
    }
    return true;
  }
  // Output 0 is selected_indices; it must be int32 + linear.
  if (pos - nbInputs == 0) {
    return desc.type == nvinfer1::DataType::kINT32 && is_linear;
  }
  return true;
}
// No shape- or format-dependent setup is needed; everything is derived at
// enqueue time from the tensor descriptors.
void NonMaxSuppressionDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
// Report the scratch memory the CUDA launcher needs. The actual layout is
// computed by get_onnxnms_workspace_size (defined in the .cu file) so both
// sides stay in sync.
size_t NonMaxSuppressionDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
  size_t boxes_word_size = mmcv::getElementSize(inputs[0].type);
  size_t num_batches = inputs[0].dims.d[0];
  size_t spatial_dimension = inputs[0].dims.d[1];
  size_t num_classes = inputs[1].dims.d[1];
  size_t output_length = outputs[0].dims.d[0];
  return get_onnxnms_workspace_size(num_batches, spatial_dimension, num_classes,
                                    boxes_word_size, mCenterPointBox,
                                    output_length);
}
// Run NMS on the GPU: read the runtime shapes from the descriptors and
// forward everything to the CUDA launcher defined in the .cu file.
int NonMaxSuppressionDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *inputDesc,
    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
    void *const *outputs, void *workSpace, cudaStream_t stream) {
  int num_batches = inputDesc[0].dims.d[0];
  int spatial_dimension = inputDesc[0].dims.d[1];
  int num_classes = inputDesc[1].dims.d[1];
  int output_length = outputDesc[0].dims.d[0];
  const float *boxes = (const float *)inputs[0];
  const float *scores = (const float *)inputs[1];
  int *output = (int *)outputs[0];
  TRTNMSCUDAKernelLauncher_float(
      boxes, scores, mMaxOutputBoxesPerClass, mIouThreshold, mScoreThreshold,
      mOffset, output, mCenterPointBox, num_batches, spatial_dimension,
      num_classes, output_length, workSpace, stream);
  // 0 signals success to TensorRT.
  return 0;
}
// IPluginV2Ext Methods
// The only output (selected_indices) is always int32, regardless of input
// types.
nvinfer1::DataType NonMaxSuppressionDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  return nvinfer1::DataType::kINT32;
}
// IPluginV2 Methods
const char *NonMaxSuppressionDynamic::getPluginType() const {
  return PLUGIN_NAME;
}
const char *NonMaxSuppressionDynamic::getPluginVersion() const {
  return PLUGIN_VERSION;
}
int NonMaxSuppressionDynamic::getNbOutputs() const { return 1; }
// The plugin owns no device resources of its own, so initialize/terminate
// are no-ops.
int NonMaxSuppressionDynamic::initialize() { return 0; }
void NonMaxSuppressionDynamic::terminate() {}
// The serialized blob holds exactly the five attributes, in the same order
// as serialize() writes them and the deserializing constructor reads them.
size_t NonMaxSuppressionDynamic::getSerializationSize() const {
  return sizeof(mCenterPointBox) + sizeof(mMaxOutputBoxesPerClass) +
         sizeof(mIouThreshold) + sizeof(mScoreThreshold) + sizeof(mOffset);
}
void NonMaxSuppressionDynamic::serialize(void *buffer) const {
  serialize_value(&buffer, mCenterPointBox);
  serialize_value(&buffer, mMaxOutputBoxesPerClass);
  serialize_value(&buffer, mIouThreshold);
  serialize_value(&buffer, mScoreThreshold);
  serialize_value(&buffer, mOffset);
}
void NonMaxSuppressionDynamic::destroy() {
  // This gets called when the network containing plugin is destroyed
  delete this;
}
void NonMaxSuppressionDynamic::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}
const char *NonMaxSuppressionDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}
////////////////////// creator /////////////////////////////
// Declare the plugin fields the ONNX parser may pass to createPlugin().
// Only names are registered here; TensorRT fills in the data pointers.
NonMaxSuppressionDynamicCreator::NonMaxSuppressionDynamicCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.emplace_back(nvinfer1::PluginField("center_point_box"));
  mPluginAttributes.emplace_back(
      nvinfer1::PluginField("max_output_boxes_per_class"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("iou_threshold"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("score_threshold"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("offset"));
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}
// Creator name/version must match the plugin's, so TensorRT can locate this
// creator when deserializing an engine.
const char *NonMaxSuppressionDynamicCreator::getPluginName() const {
  return PLUGIN_NAME;
}
const char *NonMaxSuppressionDynamicCreator::getPluginVersion() const {
  return PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection *
NonMaxSuppressionDynamicCreator::getFieldNames() {
  return &mFC;
}
// Build a plugin instance from the attributes of the ONNX node. Any field
// that is absent (or has no data) keeps its zero default.
nvinfer1::IPluginV2 *NonMaxSuppressionDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  int center_point_box = 0;
  int max_output_boxes_per_class = 0;
  float iou_threshold = 0.0f;
  float score_threshold = 0.0f;
  int offset = 0;
  for (int i = 0; i < fc->nbFields; ++i) {
    const nvinfer1::PluginField &field = fc->fields[i];
    if (field.data == nullptr) {
      continue;
    }
    const std::string field_name(field.name);
    if (field_name.compare("center_point_box") == 0) {
      center_point_box = static_cast<const int *>(field.data)[0];
    } else if (field_name.compare("max_output_boxes_per_class") == 0) {
      max_output_boxes_per_class = static_cast<const int *>(field.data)[0];
    } else if (field_name.compare("iou_threshold") == 0) {
      iou_threshold = static_cast<const float *>(field.data)[0];
    } else if (field_name.compare("score_threshold") == 0) {
      score_threshold = static_cast<const float *>(field.data)[0];
    } else if (field_name.compare("offset") == 0) {
      offset = static_cast<const int *>(field.data)[0];
    }
  }
  auto *plugin = new NonMaxSuppressionDynamic(
      name, center_point_box, max_output_boxes_per_class, iou_threshold,
      score_threshold, offset);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
// Rebuild a plugin from an engine file; the blob layout is defined by
// NonMaxSuppressionDynamic::serialize().
nvinfer1::IPluginV2 *NonMaxSuppressionDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  auto plugin = new NonMaxSuppressionDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
void NonMaxSuppressionDynamicCreator::setPluginNamespace(
    const char *libNamespace) {
  mNamespace = libNamespace;
}
const char *NonMaxSuppressionDynamicCreator::getPluginNamespace() const {
  return mNamespace.c_str();
}
#include <stdio.h>
#include <thrust/execution_policy.h>
#include <thrust/gather.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <chrono>
#include <thread>
#include <vector>
#include "common_cuda_helper.hpp"
#include "nms_cuda_kernel.cuh"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
// A plain 4-float box; lets thrust move/transform whole boxes as one value.
struct NMSBox {
  float box[4];
};
// Convert a (center_x, center_y, w, h) box into (x1, y1, x2, y2) corners.
struct nms_centerwh2xyxy {
  __host__ __device__ NMSBox operator()(const NMSBox box) {
    NMSBox out;
    out.box[0] = box.box[0] - box.box[2] / 2.0f;
    out.box[1] = box.box[1] - box.box[3] / 2.0f;
    out.box[2] = box.box[0] + box.box[2] / 2.0f;
    out.box[3] = box.box[1] + box.box[3] / 2.0f;
    return out;
  }
};
// Replace a box with a fixed "idle" box (the launcher passes the first
// sorted box); used together with nms_score_threshold in transform_if so
// low-score boxes cannot suppress anything.
struct nms_sbox_idle {
  const float* idle_box_;
  __host__ __device__ nms_sbox_idle(const float* idle_box) {
    idle_box_ = idle_box;
  }
  __host__ __device__ NMSBox operator()(const NMSBox box) {
    return {idle_box_[0], idle_box_[1], idle_box_[2], idle_box_[3]};
  }
};
// Predicate: true when a score falls below the configured threshold.
struct nms_score_threshold {
  float score_threshold_;
  __host__ __device__ nms_score_threshold(const float score_threshold) {
    score_threshold_ = score_threshold;
  }
  __host__ __device__ bool operator()(const float score) {
    return score < score_threshold_;
  }
};
// Map the sorted-order box ids in column 2 of `output` back to original box
// indices through `index_cache`.
// NOTE(review): not called anywhere in this file's launcher — confirm
// whether it is still needed before relying on it.
__global__ void nms_reindex_kernel(int n, int* output, int* index_cache) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    const int old_index = output[index * 3 + 2];
    output[index * 3 + 2] = index_cache[old_index];
  }
}
// Walk the pairwise-suppression bitmask produced by nms_cuda and append each
// surviving box to `output` as a (batch_id, cls_id, original_box_index) row.
// The launcher starts this kernel with a single thread block; `remv` lives
// in dynamic shared memory with one bit per candidate box.
__global__ void mask_to_output_kernel(const unsigned long long* dev_mask,
                                      const int* index, int* output,
                                      int* output_count, int batch_id,
                                      int cls_id, int spatial_dimension,
                                      int col_blocks,
                                      int max_output_boxes_per_class) {
  extern __shared__ unsigned long long remv[];
  // fill remv with 0
  CUDA_1D_KERNEL_LOOP(i, col_blocks) { remv[i] = 0; }
  __syncthreads();
  // Resume writing after the rows emitted for previous batch/class pairs.
  int start = *output_count;
  int out_per_class_count = 0;
  for (int i = 0; i < spatial_dimension; i++) {
    const int nblock = i / threadsPerBlock;
    const int inblock = i % threadsPerBlock;
    if (!(remv[nblock] & (1ULL << inblock))) {
      // Box i survives; only thread 0 writes the output row, but every
      // thread advances the per-class counter so the break is uniform.
      if (threadIdx.x == 0) {
        output[start * 3 + 0] = batch_id;
        output[start * 3 + 1] = cls_id;
        output[start * 3 + 2] = index[i];
        start += 1;
      }
      out_per_class_count += 1;
      if (out_per_class_count >= max_output_boxes_per_class) {
        break;
      }
      __syncthreads();
      // set every overlap box with bit 1 in remv
      const unsigned long long* p = dev_mask + i * col_blocks;
      CUDA_1D_KERNEL_LOOP(j, col_blocks) {
        if (j >= nblock) {
          remv[j] |= p[j];
        }
      }  // j
      __syncthreads();
    }
  }  // i
  // Publish the new running total for the next batch/class invocation.
  if (threadIdx.x == 0) {
    *output_count = start;
  }
}
// Total scratch-memory size (bytes) needed by TRTNMSCUDAKernelLauncher_float:
// the sum of every aligned sub-buffer the launcher carves out of `workspace`,
// in the same sizes and with the same alignment.
size_t get_onnxnms_workspace_size(size_t num_batches, size_t spatial_dimension,
                                  size_t num_classes, size_t boxes_word_size,
                                  int center_point_box, size_t output_length) {
  const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
  // Sorted copy of one batch's boxes.
  const size_t boxes_workspace =
      mmcv::getAlignedSize(spatial_dimension * 4 * boxes_word_size);
  // xyxy-converted boxes for all batches; only needed for cxcywh input.
  size_t boxes_xyxy_workspace = 0;
  if (center_point_box == 1) {
    boxes_xyxy_workspace = mmcv::getAlignedSize(
        num_batches * spatial_dimension * 4 * boxes_word_size);
  }
  // Sorted scores of one batch/class slice.
  const size_t scores_workspace =
      mmcv::getAlignedSize(spatial_dimension * boxes_word_size);
  // Pairwise-overlap bitmask produced by nms_cuda.
  const size_t mask_workspace = mmcv::getAlignedSize(
      spatial_dimension * col_blocks * sizeof(unsigned long long));
  // Index template [0, 1, 2, ...] plus its per-class sorted copy.
  const size_t index_template_workspace =
      mmcv::getAlignedSize(spatial_dimension * sizeof(int));
  const size_t index_workspace =
      mmcv::getAlignedSize(spatial_dimension * sizeof(int));
  // Running output-row counter.
  const size_t count_workspace = mmcv::getAlignedSize(sizeof(int));
  return scores_workspace + boxes_xyxy_workspace + boxes_workspace +
         mask_workspace + index_template_workspace + index_workspace +
         count_workspace;
}
/**
 * Launch the NonMaxSuppression kernel
 *
 * The NMS will be performed on each batch/class, sharing the kernel
 * implementation `nms_cuda`. For each batch/class, `boxes_sorted` and
 * `index_cache` will be sorted by scores; boxes_sorted is used in the
 * `nms_cuda` kernel. After that, the output is generated by
 * `mask_to_output_kernel` from `dev_mask` and the sorted index cache.
 *
 * @param[in] bboxes with shape [num_batch, spatial_dimension, 4], input boxes
 * @param[in] scores with shape [num_batch, num_classes, spatial_dimension],
 * input scores
 * @param[in] max_output_boxes_per_class max output boxes per class
 * @param[in] iou_threshold threshold of iou
 * @param[in] score_threshold threshold of scores
 * @param[in] offset box offset, only 0 or 1 is valid
 * @param[out] output with shape [output_length, 3], each row contain index
 * (batch_id, class_id, boxes_id), filling -1 if result is not valid.
 * @param[in] center_point_box 0 if boxes is [left, top, right, bottom] 1 if
 * boxes is [center_x, center_y, width, height]
 * @param[in] num_batches batch size of boxes and scores
 * @param[in] spatial_dimension boxes numbers each batch
 * @param[in] num_classes class numbers
 * @param[in] output_length the max output rows
 * @param[in] workspace memory for all temporary variables.
 * @param[in] stream cuda stream
 */
void TRTNMSCUDAKernelLauncher_float(const float* boxes, const float* scores,
                                    const int max_output_boxes_per_class,
                                    const float iou_threshold,
                                    const float score_threshold,
                                    const int offset, int* output,
                                    int center_point_box, int num_batches,
                                    int spatial_dimension, int num_classes,
                                    size_t output_length, void* workspace,
                                    cudaStream_t stream) {
  const int col_blocks = DIVUP(spatial_dimension, threadsPerBlock);
  // Carve sub-buffers out of `workspace`; sizes/alignment must match
  // get_onnxnms_workspace_size exactly.
  float* boxes_sorted = (float*)workspace;
  workspace = static_cast<char*>(workspace) +
              mmcv::getAlignedSize(spatial_dimension * 4 * sizeof(float));
  float* boxes_xyxy = nullptr;
  if (center_point_box == 1) {
    // Convert all cxcywh boxes to xyxy once, up front.
    boxes_xyxy = (float*)workspace;
    workspace = static_cast<char*>(workspace) +
                mmcv::getAlignedSize(num_batches * spatial_dimension * 4 *
                                     sizeof(float));
    thrust::transform(thrust::cuda::par.on(stream), (NMSBox*)boxes,
                      (NMSBox*)(boxes + num_batches * spatial_dimension * 4),
                      (NMSBox*)boxes_xyxy, nms_centerwh2xyxy());
    cudaCheckError();
  }
  float* scores_sorted = (float*)workspace;
  workspace = static_cast<char*>(workspace) +
              mmcv::getAlignedSize(spatial_dimension * sizeof(float));
  unsigned long long* dev_mask = (unsigned long long*)workspace;
  workspace = static_cast<char*>(workspace) +
              mmcv::getAlignedSize(spatial_dimension * col_blocks *
                                   sizeof(unsigned long long));
  int* index_cache = (int*)workspace;
  workspace = static_cast<char*>(workspace) +
              mmcv::getAlignedSize(spatial_dimension * sizeof(int));
  // generate sequence [0,1,2,3,4 ....]
  int* index_template = (int*)workspace;
  workspace = static_cast<char*>(workspace) +
              mmcv::getAlignedSize(spatial_dimension * sizeof(int));
  thrust::sequence(thrust::cuda::par.on(stream), index_template,
                   index_template + spatial_dimension, 0);
  // A non-positive limit means "keep everything".
  int max_output_boxes_per_class_cpu = max_output_boxes_per_class;
  if (max_output_boxes_per_class_cpu <= 0) {
    max_output_boxes_per_class_cpu = spatial_dimension;
  }
  int* output_count = (int*)workspace;
  workspace = static_cast<char*>(workspace) + mmcv::getAlignedSize(sizeof(int));
  cudaMemsetAsync(output_count, 0, sizeof(int), stream);
  // fill output with -1
  thrust::fill(thrust::cuda::par.on(stream), output, output + output_length * 3,
               -1);
  cudaCheckError();
  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
    for (int cls_id = 0; cls_id < num_classes; ++cls_id) {
      const int batch_cls_id = batch_id * num_classes + cls_id;
      // sort boxes by score
      cudaMemcpyAsync(scores_sorted, scores + batch_cls_id * spatial_dimension,
                      spatial_dimension * sizeof(float),
                      cudaMemcpyDeviceToDevice, stream);
      cudaCheckError();
      cudaMemcpyAsync(index_cache, index_template,
                      spatial_dimension * sizeof(int), cudaMemcpyDeviceToDevice,
                      stream);
      cudaCheckError();
      thrust::sort_by_key(thrust::cuda::par.on(stream), scores_sorted,
                          scores_sorted + spatial_dimension, index_cache,
                          thrust::greater<float>());
      // Gather boxes into score order (from the converted buffer if cxcywh).
      if (center_point_box == 1) {
        thrust::gather(thrust::cuda::par.on(stream), index_cache,
                       index_cache + spatial_dimension,
                       (NMSBox*)(boxes_xyxy + batch_id * spatial_dimension * 4),
                       (NMSBox*)boxes_sorted);
      } else {
        thrust::gather(thrust::cuda::par.on(stream), index_cache,
                       index_cache + spatial_dimension,
                       (NMSBox*)(boxes + batch_id * spatial_dimension * 4),
                       (NMSBox*)boxes_sorted);
      }
      cudaCheckError();
      // Neutralize boxes below the score threshold so they cannot suppress
      // other boxes (they are overwritten with the top-scoring box).
      if (score_threshold > 0.0f) {
        thrust::transform_if(
            thrust::cuda::par.on(stream), (NMSBox*)boxes_sorted,
            (NMSBox*)(boxes_sorted + spatial_dimension * 4), scores_sorted,
            (NMSBox*)boxes_sorted, nms_sbox_idle(boxes_sorted),
            nms_score_threshold(score_threshold));
      }
      nms_cuda<<<blocks, threads, 0, stream>>>(spatial_dimension, iou_threshold,
                                               offset, boxes_sorted, dev_mask);
      // will be performed when dev_mask is full.
      mask_to_output_kernel<<<1, threadsPerBlock,
                              col_blocks * sizeof(unsigned long long),
                              stream>>>(
          dev_mask, index_cache, output, output_count, batch_id, cls_id,
          spatial_dimension, col_blocks, max_output_boxes_per_class_cpu);
    }  // cls_id
  }    // batch_id
}
#include "trt_plugin.hpp" #include "trt_plugin.hpp"
#include "trt_nms.hpp"
#include "trt_roi_align.hpp" #include "trt_roi_align.hpp"
#include "trt_scatternd.hpp" #include "trt_scatternd.hpp"
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
......
...@@ -9,7 +9,6 @@ void TRTRoIAlignForwardCUDAKernelLauncher( ...@@ -9,7 +9,6 @@ void TRTRoIAlignForwardCUDAKernelLauncher(
scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned, scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned,
cudaStream_t stream) { cudaStream_t stream) {
roi_align_forward_cuda_kernel<scalar_t> roi_align_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input, rois, output, argmax_y, argmax_x, aligned_height, output_size, input, rois, output, argmax_y, argmax_x, aligned_height,
aligned_width, static_cast<scalar_t>(spatial_scale), sampling_ratio, aligned_width, static_cast<scalar_t>(spatial_scale), sampling_ratio,
......
#ifndef TRT_NMS_HPP
#define TRT_NMS_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
// TensorRT dynamic-shape plugin implementing the ONNX NonMaxSuppression
// operator, extended with mmcv's `offset` attribute (0 or 1).
class NonMaxSuppressionDynamic : public nvinfer1::IPluginV2DynamicExt {
 public:
  // Build-time constructor carrying the ONNX node attributes.
  NonMaxSuppressionDynamic(const std::string &name, int centerPointBox,
                           int maxOutputBoxesPerClass, float iouThreshold,
                           float scoreThreshold, int offset);

  // Deserialization constructor; `data`/`length` come from serialize().
  NonMaxSuppressionDynamic(const std::string name, const void *data,
                           size_t length);

  // It doesn't make sense to make NonMaxSuppressionDynamic without
  // arguments, so we delete the default constructor.
  NonMaxSuppressionDynamic() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const override;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
      nvinfer1::IExprBuilder &exprBuilder) override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *inOut,
                                 int nbInputs, int nbOutputs) override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs, void *const *outputs, void *workspace,
              cudaStream_t stream) override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const override;

  // IPluginV2 Methods
  const char *getPluginType() const override;
  const char *getPluginVersion() const override;
  int getNbOutputs() const override;
  int initialize() override;
  void terminate() override;
  size_t getSerializationSize() const override;
  void serialize(void *buffer) const override;
  void destroy() override;
  void setPluginNamespace(const char *pluginNamespace) override;
  const char *getPluginNamespace() const override;

 private:
  const std::string mLayerName;  // instance name given by the parser
  std::string mNamespace;
  // ONNX NonMaxSuppression attributes (see createPlugin for defaults).
  int mCenterPointBox;
  int mMaxOutputBoxesPerClass;
  float mIouThreshold;
  float mScoreThreshold;
  int mOffset;

 protected:
  // To prevent compiler warnings.
  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
  using nvinfer1::IPluginV2DynamicExt::enqueue;
  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
// Factory registered with TensorRT (see REGISTER_TENSORRT_PLUGIN) that
// creates or deserializes NonMaxSuppressionDynamic instances.
class NonMaxSuppressionDynamicCreator : public nvinfer1::IPluginCreator {
 public:
  NonMaxSuppressionDynamicCreator();

  const char *getPluginName() const override;
  const char *getPluginVersion() const override;
  const nvinfer1::PluginFieldCollection *getFieldNames() override;
  nvinfer1::IPluginV2 *createPlugin(
      const char *name, const nvinfer1::PluginFieldCollection *fc) override;
  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serialData,
                                         size_t serialLength) override;
  void setPluginNamespace(const char *pluginNamespace) override;
  const char *getPluginNamespace() const override;

 private:
  // Shared across instances: the declared plugin fields and the collection
  // handed to the ONNX parser.
  static nvinfer1::PluginFieldCollection mFC;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::string mNamespace;
};
#endif // TRT_NMS_HPP
...@@ -31,5 +31,11 @@ inline unsigned int getElementSize(nvinfer1::DataType t) { ...@@ -31,5 +31,11 @@ inline unsigned int getElementSize(nvinfer1::DataType t) {
throw std::runtime_error("Invalid DataType."); throw std::runtime_error("Invalid DataType.");
return 0; return 0;
} }
// Round origin_size up to the nearest multiple of aligned_number
// (default 16 bytes), for sizing aligned workspace sub-buffers.
inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) {
  const size_t num_chunks = (origin_size + aligned_number - 1) / aligned_number;
  return num_chunks * aligned_number;
}
} // namespace mmcv } // namespace mmcv
#endif // TRT_PLUGIN_HELPER_HPP #endif // TRT_PLUGIN_HELPER_HPP
import os
import sys import sys
import numpy as np import numpy as np
...@@ -23,7 +24,9 @@ class NMSop(torch.autograd.Function): ...@@ -23,7 +24,9 @@ class NMSop(torch.autograd.Function):
def symbolic(g, bboxes, scores, iou_threshold, offset): def symbolic(g, bboxes, scores, iou_threshold, offset):
from ..onnx import is_custom_op_loaded from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded() has_custom_op = is_custom_op_loaded()
if has_custom_op: # TensorRT nms plugin is aligned with original nms in ONNXRuntime
is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
if has_custom_op and (not is_trt_backend):
return g.op( return g.op(
'mmcv::NonMaxSuppression', 'mmcv::NonMaxSuppression',
bboxes, bboxes,
...@@ -295,7 +298,11 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): ...@@ -295,7 +298,11 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export(): if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
boxes = boxes[keep] boxes = boxes[keep]
scores = dets[:, -1] # -1 indexing works abnormal in TensorRT
# This assumes `dets` has 5 dimensions where
# the last dimension is score.
# TODO: more elegant way to handle the dimension issue.
scores = dets[:, 4]
else: else:
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
for id in torch.unique(idxs): for id in torch.unique(idxs):
......
import numpy as np
import onnx
import tensorrt as trt import tensorrt as trt
import torch import torch
def preprocess_onnx(onnx_model):
    """Modify onnx model to match with TensorRT plugins in mmcv.

    There are some conflicts between onnx node definitions and TensorRT
    limits. This function performs preprocessing on the onnx model to solve
    the conflicts. For example, onnx ``attribute`` is loaded in TensorRT on
    host and onnx ``input`` is loaded on device. The shape inference is
    performed on host, so any ``input`` related to shape (such as
    ``max_output_boxes_per_class`` in NonMaxSuppression) should be
    transformed to ``attribute`` before conversion.

    Arguments:
        onnx_model (onnx.ModelProto): Input onnx model.

    Returns:
        onnx.ModelProto: Modified onnx model.
    """
    graph = onnx_model.graph
    nodes = graph.node
    initializers = graph.initializer

    # Map every tensor name to the node that produces it.
    node_dict = {}
    for node in nodes:
        for output in node.output:
            if len(output) > 0:
                node_dict[output] = node

    init_dict = {_.name: _ for _ in initializers}

    def parse_data(name, typ):
        # Read a scalar constant either from a Constant node or from an
        # initializer; exporters emit NMS thresholds in both forms.
        if name in node_dict:
            const_node = node_dict[name]
            assert const_node.op_type == 'Constant'
            raw_data = const_node.attribute[0].t.raw_data
        elif name in init_dict:
            raw_data = init_dict[name].raw_data
        else:
            raise ValueError(f'{name} not found in node or initializer.')
        return np.frombuffer(raw_data, typ).item()

    nrof_node = len(nodes)
    for idx in range(nrof_node):
        node = nodes[idx]
        node_attributes = node.attribute
        node_inputs = node.input
        node_outputs = node.output
        node_name = node.name
        # process NonMaxSuppression node: move the scalar inputs
        # (max_output_boxes_per_class, iou_threshold, score_threshold) into
        # attributes so TensorRT can read them at engine-build time.
        if node.op_type == 'NonMaxSuppression':
            center_point_box = 0
            max_output_boxes_per_class = 1000000
            iou_threshold = 0.3
            score_threshold = 0.0
            offset = 0
            for attribute in node_attributes:
                if attribute.name == 'center_point_box':
                    center_point_box = attribute.i
                elif attribute.name == 'offset':
                    offset = attribute.i

            if len(node_inputs) >= 3:
                max_output_boxes_per_class = parse_data(
                    node_inputs[2], np.int64)
            if len(node_inputs) >= 4:
                iou_threshold = parse_data(node_inputs[3], np.float32)
            if len(node_inputs) >= 5:
                score_threshold = parse_data(node_inputs[4], np.float32)

            new_node = onnx.helper.make_node(
                'NonMaxSuppression',
                node_inputs[:2],
                node_outputs,
                name=node_name,
                center_point_box=center_point_box,
                max_output_boxes_per_class=max_output_boxes_per_class,
                iou_threshold=iou_threshold,
                score_threshold=score_threshold,
                offset=offset)

            # Keep the producer map and the node order consistent: the new
            # node replaces the old one at the same graph position.
            for output in node_outputs:
                if output in node_dict:
                    node_dict[output] = new_node
            nodes.insert(idx, new_node)
            nodes.remove(node)

    return onnx_model
def onnx2trt(onnx_model, def onnx2trt(onnx_model,
opt_shape_dict, opt_shape_dict,
log_level=trt.Logger.ERROR, log_level=trt.Logger.ERROR,
...@@ -46,10 +137,15 @@ def onnx2trt(onnx_model, ...@@ -46,10 +137,15 @@ def onnx2trt(onnx_model,
parser = trt.OnnxParser(network, logger) parser = trt.OnnxParser(network, logger)
if isinstance(onnx_model, str): if isinstance(onnx_model, str):
assert parser.parse_from_file(onnx_model), 'parse onnx failed.' onnx_model = onnx.load(onnx_model)
else:
assert parser.parse( onnx_model = preprocess_onnx(onnx_model)
onnx_model.SerializeToString()), 'parse onnx failed.'
if not parser.parse(onnx_model.SerializeToString()):
error_msgs = ''
for error in range(parser.num_errors):
error_msgs += f'{parser.get_error(error)}\n'
raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
# config builder # config builder
builder.max_workspace_size = max_workspace_size builder.max_workspace_size = max_workspace_size
...@@ -188,6 +284,8 @@ class TRTWraper(torch.nn.Module): ...@@ -188,6 +284,8 @@ class TRTWraper(torch.nn.Module):
for input_name, input_tensor in inputs.items(): for input_name, input_tensor in inputs.items():
idx = self.engine.get_binding_index(input_name) idx = self.engine.get_binding_index(input_name)
if input_tensor.dtype == torch.long:
input_tensor = input_tensor.int()
self.context.set_binding_shape(idx, tuple(input_tensor.shape)) self.context.set_binding_shape(idx, tuple(input_tensor.shape))
bindings[idx] = input_tensor.contiguous().data_ptr() bindings[idx] = input_tensor.contiguous().data_ptr()
......
import os import os
from functools import partial
import numpy as np import numpy as np
import onnx import onnx
import pytest import pytest
import torch import torch
import torch.nn as nn
try: try:
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
...@@ -22,7 +24,7 @@ if not is_tensorrt_plugin_loaded(): ...@@ -22,7 +24,7 @@ if not is_tensorrt_plugin_loaded():
allow_module_level=True) allow_module_level=True)
class WrapFunction(torch.nn.Module): class WrapFunction(nn.Module):
def __init__(self, wrapped_function): def __init__(self, wrapped_function):
super(WrapFunction, self).__init__() super(WrapFunction, self).__init__()
...@@ -110,6 +112,162 @@ def test_roialign(): ...@@ -110,6 +112,162 @@ def test_roialign():
assert torch.allclose(pytorch_roi_feat, trt_roi_feat) assert torch.allclose(pytorch_roi_feat, trt_roi_feat)
def test_nms():
    # End-to-end check: export mmcv.ops.nms to ONNX, build a TensorRT engine
    # with the NonMaxSuppression plugin, and compare against the PyTorch op.
    try:
        import mmcv
        from mmcv.ops import nms
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')
    # Switch NMSop.symbolic to the TensorRT-aligned ONNX export path.
    os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30
    data = mmcv.load('./tests/data/batched_nms_data.pkl')
    boxes = data['boxes'].cuda()
    scores = data['scores'].cuda()
    nms = partial(nms, iou_threshold=0.7, offset=0)
    wrapped_model = WrapFunction(nms)
    wrapped_model.cpu().eval()
    with torch.no_grad():
        torch.onnx.export(
            wrapped_model, (boxes.detach().cpu(), scores.detach().cpu()),
            # NOTE(review): onnx_file/trt_file are module-level paths defined
            # outside this chunk.
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=['boxes', 'scores'],
            output_names=['dets', 'inds'],
            opset_version=11)
    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'boxes': [list(boxes.shape),
                  list(boxes.shape),
                  list(boxes.shape)],
        'scores': [list(scores.shape),
                   list(scores.shape),
                   list(scores.shape)]
    }
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWraper(trt_file, ['boxes', 'scores'], ['dets', 'inds'])

    with torch.no_grad():
        trt_outputs = trt_model({'boxes': boxes, 'scores': scores})
        trt_dets = trt_outputs['dets']
        trt_inds = trt_outputs['inds']
        trt_inds = trt_inds.long()

    # compute pytorch_output
    with torch.no_grad():
        pytorch_outputs = wrapped_model(boxes, scores)
        pytorch_dets, pytorch_inds = pytorch_outputs

    # clean up temporary files before asserting
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    # TRT output is padded to output_length rows; trim to the PyTorch count.
    num_boxes = pytorch_dets.shape[0]
    trt_dets = trt_dets[:num_boxes, ...]
    trt_inds = trt_inds[:num_boxes]
    # Column 4 of dets is the score (see batched_nms comment in ops/nms.py).
    trt_scores = trt_dets[:, 4]
    pytorch_scores = pytorch_dets[:, 4]
    os.environ.pop('ONNX_BACKEND')
    assert torch.allclose(pytorch_scores, trt_scores, atol=1e-3)
    assert torch.equal(pytorch_inds, trt_inds)
def test_batched_nms():
    # Same flow as test_nms but through mmcv.ops.batched_nms, exercising the
    # extra `idxs` input (per-class NMS via coordinate offsets).
    try:
        import mmcv
        from mmcv.ops import batched_nms
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')
    # trt config
    os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
    fp16_mode = False
    max_workspace_size = 1 << 30
    data = mmcv.load('./tests/data/batched_nms_data.pkl')
    nms_cfg = dict(type='nms', iou_threshold=0.7)
    boxes = data['boxes'].cuda()
    scores = data['scores'].cuda()
    idxs = data['idxs'].cuda()
    class_agnostic = False
    nms = partial(batched_nms, nms_cfg=nms_cfg, class_agnostic=class_agnostic)
    wrapped_model = WrapFunction(nms)
    wrapped_model.cpu().eval()
    input_data = (boxes.detach().cpu(), scores.detach().cpu(),
                  idxs.detach().cpu())
    input_names = ['boxes', 'scores', 'idxs']
    output_names = ['dets', 'inds']
    with torch.no_grad():
        torch.onnx.export(
            wrapped_model,
            input_data,
            # NOTE(review): onnx_file/trt_file are module-level paths defined
            # outside this chunk.
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=input_names,
            output_names=output_names,
            opset_version=11)
    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'boxes': [list(boxes.shape),
                  list(boxes.shape),
                  list(boxes.shape)],
        'scores': [list(scores.shape),
                   list(scores.shape),
                   list(scores.shape)],
        'idxs': [list(idxs.shape),
                 list(idxs.shape),
                 list(idxs.shape)]
    }
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWraper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({
            'boxes': boxes,
            'scores': scores,
            'idxs': idxs
        })
        trt_dets = trt_outputs['dets']
        trt_inds = trt_outputs['inds']
        trt_inds = trt_inds.long()

    # compute pytorch_output
    with torch.no_grad():
        pytorch_outputs = wrapped_model(boxes, scores, idxs)
        pytorch_dets, pytorch_inds = pytorch_outputs

    # clean up temporary files before asserting
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    # TRT output is padded; trim to the PyTorch result count, then compare
    # scores (column 4 of dets) and kept indices.
    num_boxes = pytorch_dets.shape[0]
    trt_dets = trt_dets[:num_boxes, ...]
    trt_inds = trt_inds[:num_boxes]
    trt_scores = trt_dets[:, 4]
    pytorch_scores = pytorch_dets[:, 4]
    os.environ.pop('ONNX_BACKEND')
    assert torch.allclose(pytorch_scores, trt_scores)
    assert torch.equal(pytorch_inds, trt_inds)
def test_scatternd(): def test_scatternd():
def func(data): def func(data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment