Unverified Commit 9d1436fb authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Feature] add cummax/cummin tensorrt plugin (#1031)

* add cummax/cummin tensorrt plugin

* fix isort

* fix with clang-format

* fix with clang-format again

* add document
parent 55b4847a
...@@ -33,6 +33,18 @@ ...@@ -33,6 +33,18 @@
- [Inputs](#inputs-4) - [Inputs](#inputs-4)
- [Outputs](#outputs-4) - [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4) - [Type Constraints](#type-constraints-4)
- [cummax](#cummax)
- [Description](#description-5)
- [Parameters](#parameters-5)
- [Inputs](#inputs-5)
- [Outputs](#outputs-5)
- [Type Constraints](#type-constraints-5)
- [cummin](#cummin)
- [Description](#description-6)
- [Parameters](#parameters-6)
- [Inputs](#inputs-6)
- [Outputs](#outputs-6)
- [Type Constraints](#type-constraints-6)
<!-- TOC --> <!-- TOC -->
...@@ -227,3 +239,67 @@ Perform sample from `input` with pixel locations from `grid`. ...@@ -227,3 +239,67 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints ### Type Constraints
- T:tensor(float32, Linear) - T:tensor(float32, Linear)
## cummax
### Description
Returns a namedtuple (`values`, `indices`) where `values` is the cumulative maximum of elements of `input` in the dimension `dim`. And `indices` is the index location of each maximum value found in the dimension `dim`.
### Parameters
| Type | Parameter | Description |
| ----- | --------- | --------------------------------------- |
| `int` | `dim` | The dimension to do the operation over. |
### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>The input tensor.</dd>
</dl>
### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output values.</dd>
<dt><tt>outputs[1]</tt>: (int32, Linear)</dt>
<dd>Output indices.</dd>
</dl>
### Type Constraints
- T:tensor(float32, Linear)
## cummin
### Description
Returns a namedtuple (`values`, `indices`) where `values` is the cumulative minimum of elements of `input` in the dimension `dim`. And `indices` is the index location of each minimum value found in the dimension `dim`.
### Parameters
| Type | Parameter | Description |
| ----- | --------- | --------------------------------------- |
| `int` | `dim` | The dimension to do the operation over. |
### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>The input tensor.</dd>
</dl>
### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output values.</dd>
<dt><tt>outputs[1]</tt>: (int32, Linear)</dt>
<dd>Output indices.</dd>
</dl>
### Type Constraints
- T:tensor(float32, Linear)
...@@ -30,7 +30,9 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u ...@@ -30,7 +30,9 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 | | ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 | | NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 | | MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | master | | grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
Notes Notes
......
#include "trt_cummaxmin.hpp"
#include <assert.h>
#include "trt_serialize.hpp"
void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
int *output_index, const int *dims,
int nbDims, int cum_dim, int cum_type,
cudaStream_t stream);
void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
int *output_index, const int *dims,
int nbDims, int cum_dim, int cum_type,
cudaStream_t stream);
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"};
static const char *CUMMAX_PLUGIN_NAME{"cummax"};
static const char *CUMMIN_PLUGIN_NAME{"cummin"};
} // namespace
CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
TRT_CUMCMPTYPE cumType)
: mLayerName(name), mDim(dim), mCumType(cumType) {}
CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string name,
const void *data, size_t length)
: mLayerName(name) {
deserialize_value(&data, &length, &mDim);
deserialize_value(&data, &length, &mCumType);
}
CumMaxMinPluginDynamic::~CumMaxMinPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt *CumMaxMinPluginDynamic::clone() const {
CumMaxMinPluginDynamic *plugin =
new CumMaxMinPluginDynamic(mLayerName, mDim, mCumType);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs CumMaxMinPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
return inputs[0];
}
bool CumMaxMinPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
switch (pos) {
// input[0]
case 0:
return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kINT32) &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
// output[0]
case 1:
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
// output[1]
case 2:
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
default:
return false;
}
}
void CumMaxMinPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
size_t CumMaxMinPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
}
int CumMaxMinPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
const void *input = inputs[0];
void *output_value = outputs[0];
int *output_index = (int *)outputs[1];
const int *dims = &(inputDesc[0].dims.d[0]);
int nbDims = inputDesc[0].dims.nbDims;
switch (inputDesc[0].type) {
case nvinfer1::DataType::kFLOAT:
CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
output_index, dims, nbDims, mDim,
int(mCumType), stream);
break;
case nvinfer1::DataType::kINT32:
CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
output_index, dims, nbDims, mDim,
int(mCumType), stream);
break;
default:
break;
}
return 0;
}
nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
switch (index) {
case 0:
return inputTypes[0];
case 1:
return nvinfer1::DataType::kINT32;
default:
break;
}
}
// IPluginV2 Methods
const char *CumMaxMinPluginDynamic::getPluginType() const {
switch (mCumType) {
case TRT_CUMCMPTYPE::TRT_CUMMAX:
return CUMMAX_PLUGIN_NAME;
case TRT_CUMCMPTYPE::TRT_CUMMIN:
return CUMMIN_PLUGIN_NAME;
default:
return "UnknownCumType";
}
}
const char *CumMaxMinPluginDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}
int CumMaxMinPluginDynamic::getNbOutputs() const { return 2; }
int CumMaxMinPluginDynamic::initialize() { return 0; }
void CumMaxMinPluginDynamic::terminate() {}
size_t CumMaxMinPluginDynamic::getSerializationSize() const {
return sizeof(mDim) + sizeof(mCumType);
}
void CumMaxMinPluginDynamic::serialize(void *buffer) const {
serialize_value(&buffer, mDim);
serialize_value(&buffer, mCumType);
}
void CumMaxMinPluginDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}
void CumMaxMinPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
const char *CumMaxMinPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}
CumMaxMinPluginDynamicCreator::CumMaxMinPluginDynamicCreator(
TRT_CUMCMPTYPE cumType)
: mCumType(cumType) {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("dim"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *CumMaxMinPluginDynamicCreator::getPluginName() const {
return CUMMAXMIN_PLUGIN_NAME;
}
const char *CumMaxMinPluginDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection *
CumMaxMinPluginDynamicCreator::getFieldNames() {
return &mFC;
}
nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
int dim = 0;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("dim") == 0) {
dim = static_cast<const int *>(fc->fields[i].data)[0];
}
}
CumMaxMinPluginDynamic *plugin =
new CumMaxMinPluginDynamic(name, dim, mCumType);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
// This object will be deleted when the network is destroyed, which will
// call FCPluginDynamic::destroy()
auto plugin = new CumMaxMinPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
void CumMaxMinPluginDynamicCreator::setPluginNamespace(
const char *libNamespace) {
mNamespace = libNamespace;
}
const char *CumMaxMinPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
CumMaxPluginDynamicCreator::CumMaxPluginDynamicCreator()
: CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMAX) {}
const char *CumMaxPluginDynamicCreator::getPluginName() const {
return CUMMAX_PLUGIN_NAME;
}
CumMinPluginDynamicCreator::CumMinPluginDynamicCreator()
: CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMIN) {}
const char *CumMinPluginDynamicCreator::getPluginName() const {
return CUMMIN_PLUGIN_NAME;
}
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
using mmcv::TensorDesc;
template <typename scalar_t>
__global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
int *output_index, TensorDesc tensor_desc,
int cum_dim, int cum_type) {
const size_t cum_size = tensor_desc.shape[cum_dim];
const size_t cum_stride = tensor_desc.stride[cum_dim];
const size_t data_size =
tensor_desc.stride[0] * tensor_desc.shape[0] / cum_size;
CUDA_1D_KERNEL_LOOP(index, data_size) {
size_t cum_offset =
index / cum_stride * (cum_size * cum_stride) + index % cum_stride;
int cum_index = 0;
auto cum_value = input[cum_offset];
output_value[cum_offset] = cum_value;
output_index[cum_offset] = cum_index;
for (size_t cum_index_current = 1; cum_index_current < cum_size;
++cum_index_current) {
cum_offset += cum_stride;
const auto cum_value_current = input[cum_offset];
switch (cum_type) {
case 0: // max
if (cum_value_current > cum_value) {
cum_value = cum_value_current;
cum_index = cum_index_current;
}
break;
case 1: // min
if (cum_value_current < cum_value) {
cum_value = cum_value_current;
cum_index = cum_index_current;
}
break;
}
output_value[cum_offset] = cum_value;
output_index[cum_offset] = cum_index;
}
}
}
template <typename scalar_t>
void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
int *output_index, const int *dims, int nbDims,
int cum_dim, int cum_type, cudaStream_t stream) {
// fill tensordesc and initial
TensorDesc tensor_desc;
memset((void *)&tensor_desc, 0, sizeof(TensorDesc));
tensor_desc.dim = nbDims;
tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
tensor_desc.stride[nbDims - 1] = 1;
for (int i = nbDims - 2; i >= 0; --i) {
tensor_desc.shape[i] = dims[i];
tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
}
// cum dim should be larger than 0
cum_dim = cum_dim >= 0 ? cum_dim : (nbDims + cum_dim);
const int data_size =
tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output_value, output_index, tensor_desc, cum_dim, cum_type);
}
void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
int *output_index, const int *dims,
int nbDims, int cum_dim, int cum_type,
cudaStream_t stream) {
CumMaxMinForwardLauncher<float>(input, output_value, output_index, dims,
nbDims, cum_dim, cum_type, stream);
}
void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
int *output_index, const int *dims,
int nbDims, int cum_dim, int cum_type,
cudaStream_t stream) {
CumMaxMinForwardLauncher<int>(input, output_value, output_index, dims, nbDims,
cum_dim, cum_type, stream);
}
#include "trt_plugin.hpp" #include "trt_plugin.hpp"
#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp" #include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp" #include "trt_grid_sampler.hpp"
#include "trt_nms.hpp" #include "trt_nms.hpp"
#include "trt_roi_align.hpp" #include "trt_roi_align.hpp"
#include "trt_scatternd.hpp" #include "trt_scatternd.hpp"
REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator); REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
......
#ifndef TRT_CUMMAXMIN_HPP
#define TRT_CUMMAXMIN_HPP
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
enum TRT_CUMCMPTYPE { TRT_CUMMAX = 0, TRT_CUMMIN = 1 };
// implement of cummax and cummin
class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
CumMaxMinPluginDynamic(const std::string &name, int dim,
TRT_CUMCMPTYPE cumType);
CumMaxMinPluginDynamic(const std::string name, const void *data,
size_t length);
CumMaxMinPluginDynamic() = delete;
~CumMaxMinPluginDynamic();
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const override;
// IPluginV2 Methods
const char *getPluginType() const override;
const char *getPluginVersion() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void *buffer) const override;
void destroy() override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
protected:
const std::string mLayerName;
std::string mNamespace;
int mDim;
TRT_CUMCMPTYPE mCumType;
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
// cummax and cummin creator
class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
const char *getPluginName() const override;
const char *getPluginVersion() const override;
const nvinfer1::PluginFieldCollection *getFieldNames() override;
nvinfer1::IPluginV2 *createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serialData,
size_t serialLength) override;
void setPluginNamespace(const char *pluginNamespace) override;
const char *getPluginNamespace() const override;
protected:
TRT_CUMCMPTYPE mCumType;
nvinfer1::PluginFieldCollection mFC;
std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
// cummax creator
class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
public:
CumMaxPluginDynamicCreator();
const char *getPluginName() const override;
};
// cummin creator
class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
public:
CumMinPluginDynamicCreator();
const char *getPluginName() const override;
};
#endif TRT_CUMMAXMIN_HPP // TRT_CUMMAXMIN_HPP
...@@ -238,7 +238,7 @@ class TRTWraper(torch.nn.Module): ...@@ -238,7 +238,7 @@ class TRTWraper(torch.nn.Module):
output_names should be the same as onnx model. output_names should be the same as onnx model.
""" """
def __init__(self, engine, input_names, output_names): def __init__(self, engine, input_names=None, output_names=None):
super(TRTWraper, self).__init__() super(TRTWraper, self).__init__()
self.engine = engine self.engine = engine
if isinstance(self.engine, str): if isinstance(self.engine, str):
...@@ -250,6 +250,11 @@ class TRTWraper(torch.nn.Module): ...@@ -250,6 +250,11 @@ class TRTWraper(torch.nn.Module):
self._register_state_dict_hook(TRTWraper._on_state_dict) self._register_state_dict_hook(TRTWraper._on_state_dict)
self.context = self.engine.create_execution_context() self.context = self.engine.create_execution_context()
# get input and output names from engine
if input_names is None or output_names is None:
names = [_ for _ in self.engine]
input_names = list(filter(self.engine.binding_is_input, names))
output_names = list(set(names) - set(input_names))
self.input_names = input_names self.input_names = input_names
self.output_names = output_names self.output_names = output_names
......
import os import os
from functools import partial from functools import partial
from typing import Callable
import numpy as np import numpy as np
import onnx import onnx
...@@ -478,3 +479,99 @@ def test_grid_sample(mode, padding_mode, align_corners): ...@@ -478,3 +479,99 @@ def test_grid_sample(mode, padding_mode, align_corners):
if os.path.exists(trt_file): if os.path.exists(trt_file):
os.remove(trt_file) os.remove(trt_file)
assert torch.allclose(pytorch_results, trt_results) assert torch.allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('func', [torch.cummax, torch.cummin])
def test_cummin_cummax(func: Callable):
# Note generally `cummax` or `cummin` is exportable to ONNX
# as long as the pytorch version >= 1.5.0, since `torch.cummax`
# is only supported with torch >= 1.5.0.
# But when `cummax` or `cummin` serves as an intermediate component
# whose outputs is used as inputs for another modules, it's expected
# that pytorch version must be >= 1.7.0. Otherwise error appears like:
# `RuntimeError: tuple appears in op that does not forward tuples,
# unsupported 'kind: prim::PythonOp`.
from packaging import version
if version.parse(torch.__version__) < version.parse('1.7.0'):
pytest.skip('test_cummax_cummin should be ran with pytorch >= 1.7.0')
opset = 11
# register custom op `mmcv::cummax` and `mmcv::cummin`
from mmcv.onnx.symbolic import register_extra_symbolics
register_extra_symbolics(opset)
input_list = [
# arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
torch.rand((2, 3, 4, 1, 5)).cuda(),
torch.rand((1)).cuda()
]
input_names = ['input']
output_names = ['output', 'indices']
for input in input_list:
ndims = input.dim()
# valid dim range is [-ndims, ndims-1]
# test for all `dim` value which is valid
for dim in range(-ndims, ndims):
cummax_func = partial(func, dim=dim)
wrapped_model = WrapFunction(cummax_func).eval().cuda()
with torch.no_grad():
torch.onnx.export(
wrapped_model,
input,
onnx_file,
export_params=True,
keep_initializers_as_inputs=False,
input_names=input_names,
output_names=output_names,
opset_version=opset)
onnx_model = onnx.load(onnx_file)
# create trt engine and wraper
opt_shape_dict = {
'input':
[list(input.shape),
list(input.shape),
list(input.shape)]
}
# trt config
fp16_mode = False
max_workspace_size = 1 << 30
trt_engine = onnx2trt(
onnx_model,
opt_shape_dict,
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
# remove ONNX model after conversion
if os.path.exists(onnx_file):
os.remove(onnx_file)
# save TensorRT model
save_trt_engine(trt_engine, trt_file)
# load and wrap TensorRT model
trt_model = TRTWraper(trt_file)
# remove trt model after loading
if os.path.exists(trt_file):
os.remove(trt_file)
# compute trt output
with torch.no_grad():
trt_results = trt_model({'input': input.contiguous().clone()})
trt_output = trt_results['output']
trt_indices = trt_results['indices']
# compute pytorch output
with torch.no_grad():
pytorch_results = wrapped_model(input.clone())
pytorch_output = pytorch_results[0]
pytorch_indices = pytorch_results[1]
torch.testing.assert_allclose(trt_output, pytorch_output)
torch.testing.assert_allclose(trt_indices, pytorch_indices)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment