Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
07d681ac
You need to sign in or sign up before continuing.
Unverified
Commit
07d681ac
authored
Jul 09, 2021
by
qjqfl
Committed by
GitHub
Jul 09, 2021
Browse files
[Feature]: support tensorrt custom plugin `MMCVCornerPool` (#1179)
parent
2dc0a219
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
516 additions
and
0 deletions
+516
-0
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
+216
-0
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
+109
-0
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+2
-0
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
+111
-0
tests/test_ops/test_tensorrt.py
tests/test_ops/test_tensorrt.py
+78
-0
No files found.
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
0 → 100644
View file @
07d681ac
#include "trt_corner_pool.hpp"
#include <assert.h>
#include "trt_serialize.hpp"
// Forward declaration of the CUDA launcher implemented in
// trt_corner_pool_kernel.cu (float specialization).  `pool_type` selects the
// direction: 0=top, 1=bottom, 2=left, 3=right (see TRT_CORNER_POOL_TYPE).
void CornerPoolForwardLauncher_float(const float *input, float *output,
                                     const int batch_size, const int channels,
                                     const int height, const int width,
                                     const int pool_type, cudaStream_t stream);
namespace {
// Registry identifiers: plugin version string and the op name that the
// ONNX graph node must carry to be matched to this plugin.
static const char *PLUGIN_VERSION{"1"};
static const char *CORNER_POOL_PLUGIN_NAME{"MMCVCornerPool"};
}  // namespace
// Build-time constructor: layer name plus explicit pooling direction.
CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string &name,
                                                 TRT_CORNER_POOL_TYPE poolType)
    : mLayerName(name), mPoolType(poolType) {}

// Deserialization constructor: the pooling direction is recovered from the
// serialized byte stream written by serialize().
CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string name,
                                                 const void *data,
                                                 size_t length)
    : mLayerName(name) {
  deserialize_value(&data, &length, &mPoolType);
}

CornerPoolPluginDynamic::~CornerPoolPluginDynamic() {}
// Deep-copies the plugin (name + pooling mode) and propagates the namespace.
nvinfer1::IPluginV2DynamicExt *CornerPoolPluginDynamic::clone() const {
  auto *copy = new CornerPoolPluginDynamic(mLayerName, mPoolType);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}
// Corner pooling is shape-preserving: the single output has exactly the
// dimensions of input[0].
nvinfer1::DimsExprs CornerPoolPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  return inputs[0];
}
// Accepted tensor formats: input[0] must be fp32 in linear (kLINEAR) layout,
// and output[0] must mirror input[0]'s type and format exactly.
bool CornerPoolPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
    int nbOutputs) {
  if (pos == 0) {
    // input[0]
    return inOut[0].type == nvinfer1::DataType::kFLOAT &&
           inOut[0].format == nvinfer1::TensorFormat::kLINEAR;
  }
  if (pos == 1) {
    // output[0]
    return inOut[1].type == inOut[0].type && inOut[1].format == inOut[0].format;
  }
  return false;
}
// No shape-dependent state to precompute; intentionally empty.
void CornerPoolPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
// Corner pooling writes straight into the output tensor and needs no scratch
// workspace.  The original body computed an (unused) element size and then
// fell off the end of this non-void function — undefined behavior and a
// garbage workspace request.  Return 0 explicitly.
size_t CornerPoolPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
  return 0;
}
// Runs the corner-pooling CUDA kernel on the bound stream.
// Layout is NCHW, read from the (possibly dynamic) input descriptor.
// supportsFormatCombination guarantees fp32/kLINEAR, so the casts below are
// safe; named casts replace the original C-style `(float *)` casts, which
// silently stripped const from the input pointer.
int CornerPoolPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *inputDesc,
    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
    void *const *outputs, void *workSpace, cudaStream_t stream) {
  const void *input = inputs[0];
  void *output_value = outputs[0];

  const int batch_size = inputDesc[0].dims.d[0];
  const int channels = inputDesc[0].dims.d[1];
  const int height = inputDesc[0].dims.d[2];
  const int width = inputDesc[0].dims.d[3];

  CornerPoolForwardLauncher_float(static_cast<const float *>(input),
                                  static_cast<float *>(output_value),
                                  batch_size, channels, height, width,
                                  static_cast<int>(mPoolType), stream);

  return 0;
}
// The output dtype mirrors the input dtype.
nvinfer1::DataType CornerPoolPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  return inputTypes[0];
}
// IPluginV2 Methods
// All four pooling directions are served by the single "MMCVCornerPool"
// plugin type; anything else signals a corrupted mPoolType.
const char *CornerPoolPluginDynamic::getPluginType() const {
  switch (mPoolType) {
    case TRT_CORNER_POOL_TYPE::TRT_TOP_POOL:
    case TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL:
    case TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL:
    case TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL:
      return CORNER_POOL_PLUGIN_NAME;
    default:
      return "UnknownpoolType";
  }
}
const char *CornerPoolPluginDynamic::getPluginVersion() const {
  return PLUGIN_VERSION;
}

// Exactly one output tensor.
int CornerPoolPluginDynamic::getNbOutputs() const { return 1; }

// Nothing to allocate up-front.
int CornerPoolPluginDynamic::initialize() { return 0; }

void CornerPoolPluginDynamic::terminate() {}

// Only the pooling mode is persisted.
size_t CornerPoolPluginDynamic::getSerializationSize() const {
  return sizeof(mPoolType);
}

void CornerPoolPluginDynamic::serialize(void *buffer) const {
  serialize_value(&buffer, mPoolType);
}

void CornerPoolPluginDynamic::destroy() {
  // This gets called when the network containing plugin is destroyed
  delete this;
}
void CornerPoolPluginDynamic::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *CornerPoolPluginDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}
// Declares the single plugin attribute: an integer `mode` (0..3) that
// selects the pooling direction.
CornerPoolPluginDynamicCreator::CornerPoolPluginDynamicCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}
const char *CornerPoolPluginDynamicCreator::getPluginName() const {
  return CORNER_POOL_PLUGIN_NAME;
}

const char *CornerPoolPluginDynamicCreator::getPluginVersion() const {
  return PLUGIN_VERSION;
}

const nvinfer1::PluginFieldCollection *
CornerPoolPluginDynamicCreator::getFieldNames() {
  return &mFC;
}
// Creates a plugin instance from the ONNX node's field collection.
// Reads the integer `mode` attribute (0=top, 1=bottom, 2=left, 3=right).
// Fix: `poolType` is now initialized to a defined default — previously it was
// left uninitialized, so a missing/invalid `mode` field caused undefined
// behavior in release builds where the assert is compiled out.
nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  TRT_CORNER_POOL_TYPE poolType = TRT_CORNER_POOL_TYPE::TRT_TOP_POOL;
  int poolMode = -1;
  for (int i = 0; i < fc->nbFields; i++) {
    if (fc->fields[i].data == nullptr) {
      continue;
    }
    std::string field_name(fc->fields[i].name);
    if (field_name.compare("mode") == 0) {
      poolMode = static_cast<const int *>(fc->fields[i].data)[0];
    }
  }
  assert(poolMode >= 0 && poolMode <= 3);
  switch (poolMode) {
    case 0:
      poolType = TRT_CORNER_POOL_TYPE::TRT_TOP_POOL;
      break;
    case 1:
      poolType = TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL;
      break;
    case 2:
      poolType = TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL;
      break;
    case 3:
      poolType = TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL;
      break;
    default:
      break;
  }

  CornerPoolPluginDynamic *plugin = new CornerPoolPluginDynamic(name, poolType);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
// Rebuilds a plugin from a serialized engine blob.
nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  // This object will be deleted when the network is destroyed, which will
  // call FCPluginDynamic::destroy()
  auto plugin = new CornerPoolPluginDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
void CornerPoolPluginDynamicCreator::setPluginNamespace(
    const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *CornerPoolPluginDynamicCreator::getPluginNamespace() const {
  return mNamespace.c_str();
}
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
0 → 100644
View file @
07d681ac
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
// Directional running-max along the height axis (Top/Bottom pooling).
// One thread handles one (batch, channel, column) triple and sweeps the
// full height sequentially, so each output row is the max of itself and
// everything "behind" it in the pooling direction.
template <typename scalar_t>
__global__ void top_bottom_pool_kernel(const scalar_t *input, scalar_t *output,
                                       const int batch_size, const int channels,
                                       const int height, const int width,
                                       const int pool_type) {
  const int nthreads = batch_size * channels * width;
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int n_idx = index / (channels * width);  // batch
    int w_idx = index % width;               // width
    int c_idx = (index / width) % channels;  // channels
    int offset_n = n_idx * channels * width * height;
    int offset_n_c = offset_n + c_idx * width * height;
    int direction = -1;            // in [-1, 1], default for TopPool
    int index_start = height - 2;  // default for TopPool
    // pool_type in [0, 1]
    if (pool_type == 0) {
      // TopPool: seed the sweep by copying the bottom-most row verbatim
      output[offset_n_c + (height - 1) * width + w_idx] =
          input[offset_n_c + (height - 1) * width + w_idx];
    } else {
      // BottomPool: seed the sweep by copying the top-most row verbatim
      output[offset_n_c + w_idx] = input[offset_n_c + w_idx];
      index_start = 1;
      direction = 1;
    }
    // propagate the running max row by row
    for (int h = index_start; h >= 0 && h < height; h += direction) {
      output[offset_n_c + h * width + w_idx] =
          max(output[offset_n_c + (h - direction) * width + w_idx],
              input[offset_n_c + h * width + w_idx]);
    }
  }
}
// Directional running-max along the width axis (Left/Right pooling).
// One thread handles one (batch, channel, row) triple and sweeps the full
// width sequentially.
template <typename scalar_t>
__global__ void left_right_pool_kernel(const scalar_t *input, scalar_t *output,
                                       const int batch_size, const int channels,
                                       const int height, const int width,
                                       const int pool_type) {
  const int nthreads = batch_size * channels * height;
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int n_idx = index / (channels * height);  // batch
    int h_idx = index % height;               // height
    int c_idx = (index / height) % channels;  // channels
    int offset_n = n_idx * channels * width * height;
    int offset_n_c = offset_n + c_idx * width * height;
    int offset_n_c_h = offset_n_c + h_idx * width;
    int direction = -1;           // in [-1, 1], default for LeftPool
    int index_start = width - 2;  // default for LeftPool
    // pool_type in [2, 3]
    if (pool_type == 2) {
      // LeftPool: seed the sweep by copying the right-most column verbatim
      output[offset_n_c_h + width - 1] = input[offset_n_c_h + width - 1];
    } else {
      // RightPool: seed the sweep by copying the left-most column verbatim
      output[offset_n_c_h] = input[offset_n_c_h];
      index_start = 1;
      direction = 1;
    }
    // propagate the running max column by column
    for (int w = index_start; w >= 0 && w < width; w += direction) {
      output[offset_n_c_h + w] =
          max(output[offset_n_c_h + w - direction], input[offset_n_c_h + w]);
    }
  }
}
// Dispatches to the height-axis or width-axis kernel based on pool_type
// (0/1 = top/bottom, 2/3 = left/right); any other value is a no-op.
template <typename scalar_t>
void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
                               const int batch_size, const int channels,
                               const int height, const int width,
                               const int pool_type, cudaStream_t stream) {
  int nthreads = -1, col_block = -1;

  switch (pool_type) {
    case 0:
    case 1:
      // one thread per (batch, channel, column)
      nthreads = batch_size * channels * width;
      col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
      top_bottom_pool_kernel<scalar_t>
          <<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
              input, output, batch_size, channels, height, width, pool_type);
      break;
    case 2:
    case 3:
      // one thread per (batch, channel, row)
      nthreads = batch_size * channels * height;
      col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
      left_right_pool_kernel<scalar_t>
          <<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
              input, output, batch_size, channels, height, width, pool_type);
      break;
  }
}
// Non-templated float entry point, callable from the plugin's .cpp TU.
void CornerPoolForwardLauncher_float(const float *input, float *output,
                                     const int batch_size, const int channels,
                                     const int height, const int width,
                                     const int pool_type, cudaStream_t stream) {
  CornerPoolForwardLauncher<float>(input, output, batch_size, channels, height,
                                   width, pool_type, stream);
}
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
View file @
07d681ac
#include "trt_plugin.hpp"
#include "trt_plugin.hpp"
#include "trt_corner_pool.hpp"
#include "trt_cummaxmin.hpp"
#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
#include "trt_grid_sampler.hpp"
...
@@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
...
@@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN
(
RoIAlignPluginDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
RoIAlignPluginDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
ONNXScatterNDDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
ONNXScatterNDDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
InstanceNormalizationDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
InstanceNormalizationDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
CornerPoolPluginDynamicCreator
);
extern
"C"
{
extern
"C"
{
bool
initLibMMCVInferPlugins
()
{
return
true
;
}
bool
initLibMMCVInferPlugins
()
{
return
true
;
}
...
...
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
0 → 100644
View file @
07d681ac
#ifndef TRT_CORNER_POOL_HPP
#define TRT_CORNER_POOL_HPP
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
// Pooling direction selector; the numeric values must stay in sync with the
// integer `mode` attribute carried by the ONNX MMCVCornerPool node (0..3).
enum TRT_CORNER_POOL_TYPE {
  TRT_TOP_POOL = 0,
  TRT_BOTTOM_POOL = 1,
  TRT_LEFT_POOL = 2,
  TRT_RIGHT_POOL = 3
};
// implement of CornerPool
// TensorRT dynamic-shape plugin performing CornerNet-style corner pooling
// (directional running max over H or W) on fp32 NCHW tensors.
class CornerPoolPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
 public:
  CornerPoolPluginDynamic(const std::string &name,
                          TRT_CORNER_POOL_TYPE poolType);

  // Deserialization constructor (engine load path).
  CornerPoolPluginDynamic(const std::string name, const void *data,
                          size_t length);

  CornerPoolPluginDynamic() = delete;

  ~CornerPoolPluginDynamic();

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const override;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
      nvinfer1::IExprBuilder &exprBuilder) override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *inOut,
                                 int nbInputs, int nbOutputs) override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs, void *const *outputs, void *workspace,
              cudaStream_t stream) override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const override;

  // IPluginV2 Methods
  const char *getPluginType() const override;
  const char *getPluginVersion() const override;
  int getNbOutputs() const override;
  int initialize() override;
  void terminate() override;
  size_t getSerializationSize() const override;
  void serialize(void *buffer) const override;
  void destroy() override;
  void setPluginNamespace(const char *pluginNamespace) override;
  const char *getPluginNamespace() const override;

 protected:
  const std::string mLayerName;   // instance name from the network
  std::string mNamespace;         // plugin namespace set by the registry
  TRT_CORNER_POOL_TYPE mPoolType; // pooling direction (serialized)

 protected:
  // To prevent compiler warnings.
  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
  using nvinfer1::IPluginV2DynamicExt::enqueue;
  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};
// CornerPool creator
// Factory registered with TensorRT that creates/deserializes
// CornerPoolPluginDynamic instances for the "MMCVCornerPool" op.
class CornerPoolPluginDynamicCreator : public nvinfer1::IPluginCreator {
 public:
  CornerPoolPluginDynamicCreator();

  const char *getPluginName() const override;

  const char *getPluginVersion() const override;

  const nvinfer1::PluginFieldCollection *getFieldNames() override;

  nvinfer1::IPluginV2 *createPlugin(
      const char *name, const nvinfer1::PluginFieldCollection *fc) override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serialData,
                                         size_t serialLength) override;

  void setPluginNamespace(const char *pluginNamespace) override;

  const char *getPluginNamespace() const override;

 protected:
  nvinfer1::PluginFieldCollection mFC;
  std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::string mNamespace;
};

// Fix: the guard macro name must be in a comment — extra tokens after
// `#endif` are ill-formed (rejected under -pedantic-errors).
#endif  // TRT_CORNER_POOL_HPP
tests/test_ops/test_tensorrt.py
View file @
07d681ac
...
@@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode):
...
@@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode):
if
os
.
path
.
exists
(
trt_file
):
if
os
.
path
.
exists
(
trt_file
):
os
.
remove
(
trt_file
)
os
.
remove
(
trt_file
)
assert
torch
.
allclose
(
pytorch_results
,
trt_results
)
assert
torch
.
allclose
(
pytorch_results
,
trt_results
)
@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode):
    """End-to-end check of the MMCVCornerPool TensorRT plugin: export a
    CornerPool model to ONNX, build a TRT engine, and compare its output
    against the PyTorch reference for several input shapes."""
    try:
        from mmcv.ops import CornerPool
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    opset = 11
    # register custom op `mmcv::MMCVCornerPool`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    inputs = [
        # (n, c, h, w)
        torch.rand((2, 3, 5, 5)),
        torch.rand((1, 2, 4, 6)),
        torch.rand((2, 1, 3, 2)),
    ]

    class CornerPoolWrapper(CornerPool):

        def __init__(self, mode):
            super(CornerPoolWrapper, self).__init__(mode)

        def forward(self, x):
            # no use `torch.cummax`, instead `corner_pool` is used
            # for various torch version
            return self.corner_pool.apply(x)

    wrapped_model = CornerPoolWrapper(mode).cuda()
    for input in inputs:
        input = input.cuda()

        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (input, ),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input'],
                output_names=['output'],
                opset_version=opset)
        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(input.shape),
                      list(input.shape),
                      list(input.shape)],
        }
        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, ['input'], ['output'])

        with torch.no_grad():
            trt_outputs = trt_model({'input': input})
            trt_pool_feat = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_pool_feat = wrapped_model(input)

        # allclose: clean up exported artifacts before asserting
        if os.path.exists(onnx_file):
            os.remove(onnx_file)
        if os.path.exists(trt_file):
            os.remove(trt_file)
        assert torch.allclose(pytorch_pool_feat, trt_pool_feat, atol=1e-5)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment