Commit 4d42365a authored by RunningLeon, committed by GitHub

[Feature]: add TensorRT InstanceNormalization plugin (#1034)

* add instancenorm plugin

* resolve comments

* fix lint

* fix typo
parent 6c7d6c32
@@ -45,6 +45,12 @@
- [Inputs](#inputs-6)
- [Outputs](#outputs-6)
- [Type Constraints](#type-constraints-6)
- [MMCVInstanceNormalization](#mmcvinstancenormalization)
- [Description](#description-7)
- [Parameters](#parameters-7)
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
<!-- TOC -->
@@ -303,3 +309,39 @@ Returns a namedtuple (`values`, `indices`) where `values` is the cumulative mini
### Type Constraints
- T:tensor(float32, Linear)
## MMCVInstanceNormalization
### Description
Carries out instance normalization as described in the paper https://arxiv.org/abs/1607.08022.
y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance are computed per instance per channel.
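
As an illustration of the formula above, here is a minimal pure-Python sketch, not the plugin's implementation. It assumes the input is given as nested lists shaped `[N][C][L]` with the spatial dimensions flattened into `L`, and that the variance is the biased (population) variance. The function name `instance_norm` is hypothetical.

```python
import math


def instance_norm(x, scale, bias, epsilon=1e-5):
    """Normalize nested lists shaped [N][C][L] per sample and per channel."""
    out = []
    for sample in x:
        normed_sample = []
        for ch, values in enumerate(sample):
            mean = sum(values) / len(values)
            # Biased variance over the spatial positions of this one channel.
            var = sum((v - mean) ** 2 for v in values) / len(values)
            inv_std = 1.0 / math.sqrt(var + epsilon)
            normed_sample.append(
                [scale[ch] * (v - mean) * inv_std + bias[ch] for v in values])
        out.append(normed_sample)
    return out


# One sample, two channels, four spatial positions each.
x = [[[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 2.0, 2.0]]]
y = instance_norm(x, scale=[1.0, 2.0], bias=[0.0, 1.0])
```

After normalization, each channel of `y` is centered on its bias entry, regardless of the values in the other channel or in other samples.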
### Parameters
| Type | Parameter | Description |
| ------- | --------- | -------------------------------------------------------------------- |
| `float` | `epsilon` | The epsilon value to use to avoid division by zero. Default is 1e-05 |
### Inputs
<dl>
<dt><tt>input</tt>: T</dt>
<dd>Input data tensor from the previous operator. For the image case, the dimensions are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and width of the data. For the non-image case, the dimensions are (N x C x D1 x D2 ... Dn), where N is the batch size.</dd>
<dt><tt>scale</tt>: T</dt>
<dd>The input 1-dimensional scale tensor of size C.</dd>
<dt><tt>B</tt>: T</dt>
<dd>The input 1-dimensional bias tensor of size C.</dd>
</dl>
### Outputs
<dl>
<dt><tt>output</tt>: T</dt>
<dd>The output tensor of the same shape as input.</dd>
</dl>
### Type Constraints
- T:tensor(float32, Linear)
@@ -24,16 +24,16 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
## List of TensorRT plugins supported in MMCV
| ONNX Operator | TensorRT Plugin | MMCV Releases |
| :-----------------------: | :-----------------------------------------------------------------------------: | :-----------: |
| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 |
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |
Notes
- All plugins listed above are developed on TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0
......
// Modified from:
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.cpp
#include "trt_instance_norm.hpp"
#include <cuda_fp16.h>
#include <cstring>  // strcmp, used in createPlugin
#include <stdexcept>
#include "trt_serialize.hpp"
using namespace nvinfer1;
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
cudnnDataType_t* cudnn_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
*cudnn_dtype = CUDNN_DATA_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*cudnn_dtype = CUDNN_DATA_HALF;
break;
default:
return CUDNN_STATUS_BAD_PARAM;
}
return CUDNN_STATUS_SUCCESS;
}
namespace {
constexpr const char* PLUGIN_VERSION{"1"};
constexpr const char* PLUGIN_NAME{"MMCVInstanceNormalization"};
} // namespace
PluginFieldCollection InstanceNormalizationDynamicCreator::mFC{};
std::vector<PluginField> InstanceNormalizationDynamicCreator::mPluginAttributes;
InstanceNormalizationDynamic::InstanceNormalizationDynamic(
const std::string& name, float epsilon)
: mLayerName(name), mEpsilon(epsilon) {}
InstanceNormalizationDynamic::InstanceNormalizationDynamic(
const std::string& name, void const* serialData, size_t serialLength)
: mLayerName(name) {
deserialize_value(&serialData, &serialLength, &mEpsilon);
}
InstanceNormalizationDynamic::~InstanceNormalizationDynamic() {}
// InstanceNormalizationDynamic returns one output.
int InstanceNormalizationDynamic::getNbOutputs() const { return 1; }
DimsExprs InstanceNormalizationDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) {
nvinfer1::DimsExprs output(inputs[0]);
return output;
}
int InstanceNormalizationDynamic::initialize() { return 0; }
void InstanceNormalizationDynamic::terminate() {}
size_t InstanceNormalizationDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
int n = inputs[0].dims.d[0];
int c = inputs[0].dims.d[1];
int elem_size = mmcv::getElementSize(inputs[1].type);
return mmcv::getAlignedSize(n * c * elem_size) * 2;
}
int InstanceNormalizationDynamic::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int n = input_dims.d[0];
int c = input_dims.d[1];
int h = input_dims.d[2];
int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1;
int elem_size = mmcv::getElementSize(inputDesc[1].type);
// Use char* so the byte offsets below are well-defined C++ (arithmetic on
// void* is only a compiler extension).
char* n_scales = static_cast<char*>(workspace);
char* n_bias = n_scales + mmcv::getAlignedSize(n * c * elem_size);
const void* scales = inputs[1];
const void* bias = inputs[2];
// Repeat the C-length scale and bias once per sample to get N*C entries.
for (int i = 0; i < n; ++i) {
cudaMemcpyAsync(n_scales + i * c * elem_size, scales, c * elem_size,
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(n_bias + i * c * elem_size, bias, c * elem_size,
cudaMemcpyDeviceToDevice, stream);
}
cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1,
n * c, 1, 1);
cudnnDataType_t cudnn_dtype{};
convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype);
cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
h, w);
cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
h, w);
float alpha = 1;
float beta = 0;
void const* x_ptr = inputs[0];
void* y_ptr = outputs[0];
cudnnSetStream(_cudnn_handle, stream);
// Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
// overflows (NaNs) for fp32 data in some circumstances. The lower-
// performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
// acceptable.
cudnnBatchNormalizationForwardTraining(
_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc,
x_ptr, _y_desc, y_ptr, _b_desc, n_scales, n_bias, 1., nullptr, nullptr,
mEpsilon, nullptr, nullptr);
return 0;
}
size_t InstanceNormalizationDynamic::getSerializationSize() const {
return serialized_size(mEpsilon);
}
void InstanceNormalizationDynamic::serialize(void* buffer) const {
serialize_value(&buffer, mEpsilon);
}
bool InstanceNormalizationDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) {
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kHALF) &&
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
inOut[pos].type == inOut[0].type);
}
const char* InstanceNormalizationDynamic::getPluginType() const {
return PLUGIN_NAME;
}
const char* InstanceNormalizationDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}
void InstanceNormalizationDynamic::destroy() { delete this; }
IPluginV2DynamicExt* InstanceNormalizationDynamic::clone() const {
auto* plugin = new InstanceNormalizationDynamic{mLayerName, mEpsilon};
plugin->setPluginNamespace(mPluginNamespace.c_str());
return plugin;
}
// Set plugin namespace
void InstanceNormalizationDynamic::setPluginNamespace(
const char* pluginNamespace) {
mPluginNamespace = pluginNamespace;
}
const char* InstanceNormalizationDynamic::getPluginNamespace() const {
return mPluginNamespace.c_str();
}
nvinfer1::DataType InstanceNormalizationDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return inputTypes[0];
}
// Attach the plugin object to an execution context and grant the plugin the
// access to some context resource.
void InstanceNormalizationDynamic::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext,
IGpuAllocator* gpuAllocator) {
_cudnn_handle = cudnnContext;
cudnnCreateTensorDescriptor(&_b_desc);
cudnnCreateTensorDescriptor(&_x_desc);
cudnnCreateTensorDescriptor(&_y_desc);
}
// Detach the plugin object from its execution context.
void InstanceNormalizationDynamic::detachFromContext() {
cudnnDestroyTensorDescriptor(_y_desc);
cudnnDestroyTensorDescriptor(_x_desc);
cudnnDestroyTensorDescriptor(_b_desc);
}
void InstanceNormalizationDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
// InstanceNormalizationDynamicCreator methods
InstanceNormalizationDynamicCreator::InstanceNormalizationDynamicCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(
PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char* InstanceNormalizationDynamicCreator::getPluginName() const {
return PLUGIN_NAME;
}
const char* InstanceNormalizationDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}
const PluginFieldCollection*
InstanceNormalizationDynamicCreator::getFieldNames() {
return &mFC;
}
IPluginV2DynamicExt* InstanceNormalizationDynamicCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
float epsilon = 1e-5;
const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i) {
const char* attrName = fields[i].name;
if (!strcmp(attrName, "epsilon")) {
epsilon = *(static_cast<const float*>(fields[i].data));
}
}
InstanceNormalizationDynamic* obj =
new InstanceNormalizationDynamic(name, epsilon);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
IPluginV2DynamicExt* InstanceNormalizationDynamicCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) {
InstanceNormalizationDynamic* obj =
new InstanceNormalizationDynamic{name, serialData, serialLength};
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
void InstanceNormalizationDynamicCreator::setPluginNamespace(
const char* libNamespace) {
mNamespace = libNamespace;
}
const char* InstanceNormalizationDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
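
The `enqueue` method above computes instance normalization by describing the (N, C, H, W) input to cuDNN as a (1, N*C, H, W) tensor, repeating the C-length scale and bias N times, and running training-mode spatial batch normalization. The following pure-Python sketch illustrates why that folding gives the same result as normalizing each (sample, channel) pair directly. It is an illustration under stated assumptions, not the plugin's code: tensors are nested lists shaped `[N][C][L]` with the spatial dimensions flattened, and all function names are hypothetical.

```python
import math


def batch_norm_train(x, scale, bias, eps):
    """Training-mode spatial batch norm on nested lists shaped [B][C][L]."""
    num_batches, num_channels = len(x), len(x[0])
    out = [[None] * num_channels for _ in range(num_batches)]
    for ch in range(num_channels):
        # Statistics for one channel are pooled over the batch and spatial dims.
        values = [v for b in range(num_batches) for v in x[b][ch]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        for b in range(num_batches):
            out[b][ch] = [scale[ch] * (v - mean) * inv_std + bias[ch]
                          for v in x[b][ch]]
    return out


def instance_norm_via_batch_norm(x, scale, bias, eps=1e-5):
    n, c = len(x), len(x[0])
    # View [N][C][L] as [1][N*C][L]: each (sample, channel) pair is a channel.
    folded = [[x[i][j] for i in range(n) for j in range(c)]]
    # Repeat the C-length parameters N times, mirroring the copy loop in enqueue.
    tiled_scale = list(scale) * n
    tiled_bias = list(bias) * n
    y = batch_norm_train(folded, tiled_scale, tiled_bias, eps)[0]
    # Unfold back to [N][C][L].
    return [[y[i * c + j] for j in range(c)] for i in range(n)]


def instance_norm_reference(x, scale, bias, eps=1e-5):
    """Direct per-sample, per-channel normalization for comparison."""
    out = []
    for sample in x:
        row = []
        for ch, values in enumerate(sample):
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            inv_std = 1.0 / math.sqrt(var + eps)
            row.append([scale[ch] * (v - mean) * inv_std + bias[ch]
                        for v in values])
        out.append(row)
    return out


x = [[[1.0, 2.0, 3.0, 4.0], [0.0, 5.0, 1.0, 2.0]],
     [[9.0, 7.0, 8.0, 1.0], [3.0, 3.0, 4.0, 8.0]]]
scale, bias = [1.5, 0.5], [0.1, -0.2]
folded_result = instance_norm_via_batch_norm(x, scale, bias)
direct_result = instance_norm_reference(x, scale, bias)
```

With a batch dimension of 1, the per-channel batch statistics reduce to per-(sample, channel) statistics, which is the definition of instance normalization.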
@@ -3,6 +3,7 @@
#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
#include "trt_instance_norm.hpp"
#include "trt_nms.hpp"
#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"
@@ -14,6 +15,7 @@ REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator);
extern "C" {
bool initLibMMCVInferPlugins() { return true; }
......
// Modified from:
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.h
#ifndef TRT_INSTANCE_NORMALIZATION_PLUGIN_H
#define TRT_INSTANCE_NORMALIZATION_PLUGIN_H
#include <cudnn.h>
#include <iostream>
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
typedef unsigned short half_type;
class InstanceNormalizationDynamic final
: public nvinfer1::IPluginV2DynamicExt {
public:
InstanceNormalizationDynamic(const std::string& name, float epsilon);
InstanceNormalizationDynamic(const std::string& name, void const* serialData,
size_t serialLength);
InstanceNormalizationDynamic() = delete;
~InstanceNormalizationDynamic() override;
int getNbOutputs() const override;
// DynamicExt plugins return a DimsExprs object instead of Dims
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
int initialize() override;
void terminate() override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
// DynamicExt plugin supportsFormat update.
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
const char* getPluginType() const override;
const char* getPluginVersion() const override;
void destroy() override;
nvinfer1::IPluginV2DynamicExt* clone() const override;
void setPluginNamespace(const char* pluginNamespace) override;
const char* getPluginNamespace() const override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
nvinfer1::IGpuAllocator* allocator) override;
void detachFromContext() override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override;
private:
const std::string mLayerName;
float mEpsilon{};
cudnnHandle_t _cudnn_handle{};
cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{};
std::string mPluginNamespace{};
};
class InstanceNormalizationDynamicCreator : public nvinfer1::IPluginCreator {
public:
InstanceNormalizationDynamicCreator();
~InstanceNormalizationDynamicCreator() override = default;
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2DynamicExt* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2DynamicExt* deserializePlugin(
const char* name, const void* serialData, size_t serialLength) override;
void setPluginNamespace(const char* pluginNamespace) override;
const char* getPluginNamespace() const override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
#endif // TRT_INSTANCE_NORMALIZATION_PLUGIN_H
-/*
- * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// Modified from:
+// https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp
#ifndef TRT_SERIALIZE_HPP
#define TRT_SERIALIZE_HPP
#include <cassert>
......
@@ -91,7 +91,9 @@ def preprocess_onnx(onnx_model):
node_dict[output] = new_node
nodes.insert(idx, new_node)
nodes.remove(node)
elif node.op_type == 'InstanceNormalization':
# directly change op name
node.op_type = 'MMCVInstanceNormalization'
return onnx_model
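
The branch added above only rewrites the node's `op_type`. The new name matches the `PLUGIN_NAME` registered by the plugin source, presumably so that the converted graph resolves the node to the MMCV plugin rather than a built-in layer. The sketch below shows the renaming step in isolation. The `Node` dataclass and the `rename_instance_norm` function are hypothetical stand-ins introduced for illustration; they are not part of `onnx` or `mmcv`.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a graph node; only op_type matters here."""
    name: str
    op_type: str


def rename_instance_norm(nodes):
    """Point every InstanceNormalization node at the custom plugin name."""
    for node in nodes:
        if node.op_type == 'InstanceNormalization':
            node.op_type = 'MMCVInstanceNormalization'
    return nodes


graph_nodes = [Node('conv1', 'Conv'), Node('norm1', 'InstanceNormalization')]
renamed = rename_instance_norm(graph_nodes)
op_types = [node.op_type for node in renamed]
```

Inputs, outputs, and attributes such as `epsilon` are left untouched, so the renamed node still carries everything the plugin creator reads.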
......
@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
try:
-from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
+from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)
except ImportError:
pytest.skip(
@@ -95,7 +95,7 @@ def test_roialign():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat'])
+trt_model = TRTWrapper(trt_file, ['input', 'rois'], ['roi_feat'])
with torch.no_grad():
trt_outputs = trt_model({'input': input, 'rois': rois})
@@ -155,7 +155,7 @@ def test_nms():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, ['boxes', 'scores'], ['dets', 'inds'])
+trt_model = TRTWrapper(trt_file, ['boxes', 'scores'], ['dets', 'inds'])
with torch.no_grad():
trt_outputs = trt_model({'boxes': boxes, 'scores': scores})
@@ -237,7 +237,7 @@ def test_batched_nms():
fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, input_names, output_names)
+trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({
@@ -311,7 +311,7 @@ def test_scatternd():
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, input_names, output_names)
+trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': data.clone()})
@@ -387,7 +387,7 @@ def test_deform_conv():
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, input_names, output_names)
+trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': x.clone()})
@@ -463,7 +463,7 @@ def test_grid_sample(mode, padding_mode, align_corners):
max_workspace_size=max_workspace_size)
save_trt_engine(trt_engine, trt_file)
-trt_model = TRTWraper(trt_file, input_names, output_names)
+trt_model = TRTWrapper(trt_file, input_names, output_names)
with torch.no_grad():
trt_outputs = trt_model({'input': input.clone(), 'grid': grid.clone()})
@@ -555,7 +555,7 @@ def test_cummin_cummax(func: Callable):
save_trt_engine(trt_engine, trt_file)
# load and wrap TensorRT model
-trt_model = TRTWraper(trt_file)
+trt_model = TRTWrapper(trt_file)
# remove trt model after loading
if os.path.exists(trt_file):
@@ -575,3 +575,83 @@ def test_cummin_cummax(func: Callable):
torch.testing.assert_allclose(trt_output, pytorch_output)
torch.testing.assert_allclose(trt_indices, pytorch_indices)
@pytest.mark.parametrize('dynamic_export', [True, False])
@pytest.mark.parametrize('fp16_mode', [True, False])
def test_instance_norm(dynamic_export, fp16_mode):
    n, c, h, w = 2, 3, 10, 10
    data = torch.randn(n, c, h, w).cuda()
    norm = nn.InstanceNorm2d(c, affine=True)
    wrapped_model = WrapFunction(norm).eval().cuda()
    input_names = ['input']
    output_names = ['output']
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }
    with torch.no_grad():
        torch.onnx.export(
            wrapped_model, (data.clone(), ),
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=11)
    onnx_model = onnx.load(onnx_file)
    # create trt engine and wrapper
    if dynamic_export:
        opt_shape_dict = {
            'input':
            [list(data.shape),
             list(data.shape), [2 * n, c, 2 * h, 2 * w]],
        }
    else:
        opt_shape_dict = {
            'input': [list(data.shape),
                      list(data.shape),
                      list(data.shape)],
        }
    # trt config
    max_workspace_size = 1 << 30
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)
    with torch.no_grad():
        trt_outputs = trt_model({'input': data.clone()})
        trt_results = trt_outputs['output']
    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = wrapped_model(data.clone())
    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    assert torch.allclose(pytorch_results, trt_results)
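
The test above compares results with the default tolerances of `torch.allclose` (`rtol=1e-05`, `atol=1e-08`) in both parametrizations, including `fp16_mode=True`. The sketch below is a standalone, stdlib-only illustration of how values rounded to half precision relate to those defaults. It does not run TensorRT, and whether the engine actually selects FP16 kernels for this layer is not established by the source. The helper names `to_fp16` and `all_close` are hypothetical; `all_close` follows the documented rule `|a - b| <= atol + rtol * |b|`.

```python
import struct


def to_fp16(value):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


def all_close(actual, expected, rtol=1e-5, atol=1e-8):
    """Element-wise check of |a - b| <= atol + rtol * |b|."""
    return all(
        abs(a - b) <= atol + rtol * abs(b) for a, b in zip(actual, expected))


reference = [0.1, -0.7, 1.3, 2.9]
half_rounded = [to_fp16(v) for v in reference]

# Default tolerances reject the half-precision rounding error.
strict_ok = all_close(half_rounded, reference)
# Looser tolerances (values chosen for illustration only) accept it.
loose_ok = all_close(half_rounded, reference, rtol=1e-3, atol=1e-3)
```

If the FP16 variant of the test turns out to be flaky in practice, passing explicit `rtol` and `atol` for that parametrization is one option to consider.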