Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
07d681ac
Unverified
Commit
07d681ac
authored
Jul 09, 2021
by
qjqfl
Committed by
GitHub
Jul 09, 2021
Browse files
[Feature]: support tensorrt custom plugin `MMCVCornerPool` (#1179)
parent
2dc0a219
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
516 additions
and
0 deletions
+516
-0
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
+216
-0
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
+109
-0
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+2
-0
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
+111
-0
tests/test_ops/test_tensorrt.py
tests/test_ops/test_tensorrt.py
+78
-0
No files found.
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
0 → 100644
View file @
07d681ac
#include "trt_corner_pool.hpp"
#include <assert.h>
#include "trt_serialize.hpp"
void
CornerPoolForwardLauncher_float
(
const
float
*
input
,
float
*
output
,
const
int
batch_size
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pool_type
,
cudaStream_t
stream
);
namespace
{
static
const
char
*
PLUGIN_VERSION
{
"1"
};
static
const
char
*
CORNER_POOL_PLUGIN_NAME
{
"MMCVCornerPool"
};
}
// namespace
CornerPoolPluginDynamic
::
CornerPoolPluginDynamic
(
const
std
::
string
&
name
,
TRT_CORNER_POOL_TYPE
poolType
)
:
mLayerName
(
name
),
mPoolType
(
poolType
)
{}
CornerPoolPluginDynamic
::
CornerPoolPluginDynamic
(
const
std
::
string
name
,
const
void
*
data
,
size_t
length
)
:
mLayerName
(
name
)
{
deserialize_value
(
&
data
,
&
length
,
&
mPoolType
);
}
CornerPoolPluginDynamic
::~
CornerPoolPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
CornerPoolPluginDynamic
::
clone
()
const
{
CornerPoolPluginDynamic
*
plugin
=
new
CornerPoolPluginDynamic
(
mLayerName
,
mPoolType
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
DimsExprs
CornerPoolPluginDynamic
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
{
return
inputs
[
0
];
}
bool
CornerPoolPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
{
switch
(
pos
)
{
// input[0]
case
0
:
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
// output[0]
case
1
:
return
inOut
[
pos
].
type
==
inOut
[
0
].
type
&&
inOut
[
pos
].
format
==
inOut
[
0
].
format
;
default:
return
false
;
}
}
void
CornerPoolPluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
{}
size_t
CornerPoolPluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
{
int
sizeof_dtype
=
mmcv
::
getElementSize
(
outputs
[
0
].
type
);
}
int
CornerPoolPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workSpace
,
cudaStream_t
stream
)
{
const
void
*
input
=
inputs
[
0
];
void
*
output_value
=
outputs
[
0
];
const
int
batch_size
=
inputDesc
[
0
].
dims
.
d
[
0
];
const
int
channels
=
inputDesc
[
0
].
dims
.
d
[
1
];
const
int
height
=
inputDesc
[
0
].
dims
.
d
[
2
];
const
int
width
=
inputDesc
[
0
].
dims
.
d
[
3
];
CornerPoolForwardLauncher_float
((
float
*
)
input
,
(
float
*
)
output_value
,
batch_size
,
channels
,
height
,
width
,
int
(
mPoolType
),
stream
);
return
0
;
}
nvinfer1
::
DataType
CornerPoolPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
{
return
inputTypes
[
0
];
}
// IPluginV2 Methods
const
char
*
CornerPoolPluginDynamic
::
getPluginType
()
const
{
switch
(
mPoolType
)
{
case
TRT_CORNER_POOL_TYPE
::
TRT_TOP_POOL
:
case
TRT_CORNER_POOL_TYPE
::
TRT_BOTTOM_POOL
:
case
TRT_CORNER_POOL_TYPE
::
TRT_LEFT_POOL
:
case
TRT_CORNER_POOL_TYPE
::
TRT_RIGHT_POOL
:
return
CORNER_POOL_PLUGIN_NAME
;
default:
return
"UnknownpoolType"
;
}
}
const
char
*
CornerPoolPluginDynamic
::
getPluginVersion
()
const
{
return
PLUGIN_VERSION
;
}
int
CornerPoolPluginDynamic
::
getNbOutputs
()
const
{
return
1
;
}
int
CornerPoolPluginDynamic
::
initialize
()
{
return
0
;
}
void
CornerPoolPluginDynamic
::
terminate
()
{}
size_t
CornerPoolPluginDynamic
::
getSerializationSize
()
const
{
return
sizeof
(
mPoolType
);
}
void
CornerPoolPluginDynamic
::
serialize
(
void
*
buffer
)
const
{
serialize_value
(
&
buffer
,
mPoolType
);
}
void
CornerPoolPluginDynamic
::
destroy
()
{
// This gets called when the network containing plugin is destroyed
delete
this
;
}
void
CornerPoolPluginDynamic
::
setPluginNamespace
(
const
char
*
libNamespace
)
{
mNamespace
=
libNamespace
;
}
const
char
*
CornerPoolPluginDynamic
::
getPluginNamespace
()
const
{
return
mNamespace
.
c_str
();
}
CornerPoolPluginDynamicCreator
::
CornerPoolPluginDynamicCreator
()
{
mPluginAttributes
.
clear
();
mPluginAttributes
.
emplace_back
(
nvinfer1
::
PluginField
(
"mode"
));
mFC
.
nbFields
=
mPluginAttributes
.
size
();
mFC
.
fields
=
mPluginAttributes
.
data
();
}
const
char
*
CornerPoolPluginDynamicCreator
::
getPluginName
()
const
{
return
CORNER_POOL_PLUGIN_NAME
;
}
const
char
*
CornerPoolPluginDynamicCreator
::
getPluginVersion
()
const
{
return
PLUGIN_VERSION
;
}
const
nvinfer1
::
PluginFieldCollection
*
CornerPoolPluginDynamicCreator
::
getFieldNames
()
{
return
&
mFC
;
}
nvinfer1
::
IPluginV2
*
CornerPoolPluginDynamicCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
{
TRT_CORNER_POOL_TYPE
poolType
;
int
poolMode
=
-
1
;
for
(
int
i
=
0
;
i
<
fc
->
nbFields
;
i
++
)
{
if
(
fc
->
fields
[
i
].
data
==
nullptr
)
{
continue
;
}
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
if
(
field_name
.
compare
(
"mode"
)
==
0
)
{
poolMode
=
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
)[
0
];
}
}
assert
(
poolMode
>=
0
&&
poolMode
<=
3
);
switch
(
poolMode
)
{
case
0
:
poolType
=
TRT_CORNER_POOL_TYPE
::
TRT_TOP_POOL
;
break
;
case
1
:
poolType
=
TRT_CORNER_POOL_TYPE
::
TRT_BOTTOM_POOL
;
break
;
case
2
:
poolType
=
TRT_CORNER_POOL_TYPE
::
TRT_LEFT_POOL
;
break
;
case
3
:
poolType
=
TRT_CORNER_POOL_TYPE
::
TRT_RIGHT_POOL
;
break
;
default:
break
;
}
CornerPoolPluginDynamic
*
plugin
=
new
CornerPoolPluginDynamic
(
name
,
poolType
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
IPluginV2
*
CornerPoolPluginDynamicCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serialData
,
size_t
serialLength
)
{
// This object will be deleted when the network is destroyed, which will
// call FCPluginDynamic::destroy()
auto
plugin
=
new
CornerPoolPluginDynamic
(
name
,
serialData
,
serialLength
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
void
CornerPoolPluginDynamicCreator
::
setPluginNamespace
(
const
char
*
libNamespace
)
{
mNamespace
=
libNamespace
;
}
const
char
*
CornerPoolPluginDynamicCreator
::
getPluginNamespace
()
const
{
return
mNamespace
.
c_str
();
}
mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
0 → 100644
View file @
07d681ac
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
template
<
typename
scalar_t
>
__global__
void
top_bottom_pool_kernel
(
const
scalar_t
*
input
,
scalar_t
*
output
,
const
int
batch_size
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pool_type
)
{
const
int
nthreads
=
batch_size
*
channels
*
width
;
CUDA_1D_KERNEL_LOOP
(
index
,
nthreads
)
{
int
n_idx
=
index
/
(
channels
*
width
);
// batch
int
w_idx
=
index
%
width
;
// width
int
c_idx
=
(
index
/
width
)
%
channels
;
// channels
int
offset_n
=
n_idx
*
channels
*
width
*
height
;
int
offset_n_c
=
offset_n
+
c_idx
*
width
*
height
;
int
direction
=
-
1
;
// in [-1, 1], default for TopPool
int
index_start
=
height
-
2
;
// default for TopPool
// pool_type in [0, 1]
if
(
pool_type
==
0
)
{
// TopPool
// directly copy the most bottom value from input to output
output
[
offset_n_c
+
(
height
-
1
)
*
width
+
w_idx
]
=
input
[
offset_n_c
+
(
height
-
1
)
*
width
+
w_idx
];
}
else
{
// BottomPool
// directly copy the most top value from input to output
output
[
offset_n_c
+
w_idx
]
=
input
[
offset_n_c
+
w_idx
];
index_start
=
1
;
direction
=
1
;
}
// do pool
for
(
int
h
=
index_start
;
h
>=
0
&&
h
<
height
;
h
+=
direction
)
{
output
[
offset_n_c
+
h
*
width
+
w_idx
]
=
max
(
output
[
offset_n_c
+
(
h
-
direction
)
*
width
+
w_idx
],
input
[
offset_n_c
+
h
*
width
+
w_idx
]);
}
}
}
template
<
typename
scalar_t
>
__global__
void
left_right_pool_kernel
(
const
scalar_t
*
input
,
scalar_t
*
output
,
const
int
batch_size
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pool_type
)
{
const
int
nthreads
=
batch_size
*
channels
*
height
;
CUDA_1D_KERNEL_LOOP
(
index
,
nthreads
)
{
int
n_idx
=
index
/
(
channels
*
height
);
// batch
int
h_idx
=
index
%
height
;
// height
int
c_idx
=
(
index
/
height
)
%
channels
;
// channels
int
offset_n
=
n_idx
*
channels
*
width
*
height
;
int
offset_n_c
=
offset_n
+
c_idx
*
width
*
height
;
int
offset_n_c_h
=
offset_n_c
+
h_idx
*
width
;
int
direction
=
-
1
;
// in [-1, 1], default for LeftPool
int
index_start
=
width
-
2
;
// default for LeftPool
// pool_type in [2, 3]
if
(
pool_type
==
2
)
{
// LeftPool
// directly copy the most right value from input to output
output
[
offset_n_c_h
+
width
-
1
]
=
input
[
offset_n_c_h
+
width
-
1
];
}
else
{
// RightPool
// directly copy the most left value from input to output
output
[
offset_n_c_h
]
=
input
[
offset_n_c_h
];
index_start
=
1
;
direction
=
1
;
}
// do pool
for
(
int
w
=
index_start
;
w
>=
0
&&
w
<
width
;
w
+=
direction
)
{
output
[
offset_n_c_h
+
w
]
=
max
(
output
[
offset_n_c_h
+
w
-
direction
],
input
[
offset_n_c_h
+
w
]);
}
}
}
template
<
typename
scalar_t
>
void
CornerPoolForwardLauncher
(
const
scalar_t
*
input
,
scalar_t
*
output
,
const
int
batch_size
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pool_type
,
cudaStream_t
stream
)
{
int
nthreads
=
-
1
,
col_block
=
-
1
;
switch
(
pool_type
)
{
case
0
:
case
1
:
nthreads
=
batch_size
*
channels
*
width
;
col_block
=
DIVUP
(
nthreads
,
THREADS_PER_BLOCK
);
top_bottom_pool_kernel
<
scalar_t
>
<<<
col_block
,
THREADS_PER_BLOCK
,
0
,
stream
>>>
(
input
,
output
,
batch_size
,
channels
,
height
,
width
,
pool_type
);
break
;
case
2
:
case
3
:
nthreads
=
batch_size
*
channels
*
height
;
col_block
=
DIVUP
(
nthreads
,
THREADS_PER_BLOCK
);
left_right_pool_kernel
<
scalar_t
>
<<<
col_block
,
THREADS_PER_BLOCK
,
0
,
stream
>>>
(
input
,
output
,
batch_size
,
channels
,
height
,
width
,
pool_type
);
break
;
}
}
void
CornerPoolForwardLauncher_float
(
const
float
*
input
,
float
*
output
,
const
int
batch_size
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
pool_type
,
cudaStream_t
stream
)
{
CornerPoolForwardLauncher
<
float
>
(
input
,
output
,
batch_size
,
channels
,
height
,
width
,
pool_type
,
stream
);
}
mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
View file @
07d681ac
#include "trt_plugin.hpp"
#include "trt_corner_pool.hpp"
#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
...
...
@@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN
(
RoIAlignPluginDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
ONNXScatterNDDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
InstanceNormalizationDynamicCreator
);
REGISTER_TENSORRT_PLUGIN
(
CornerPoolPluginDynamicCreator
);
extern
"C"
{
bool
initLibMMCVInferPlugins
()
{
return
true
;
}
...
...
mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp
0 → 100644
View file @
07d681ac
#ifndef TRT_CORNER_POOL_HPP
#define TRT_CORNER_POOL_HPP
#include <string>
#include <vector>
#include "trt_plugin_helper.hpp"
enum
TRT_CORNER_POOL_TYPE
{
TRT_TOP_POOL
=
0
,
TRT_BOTTOM_POOL
=
1
,
TRT_LEFT_POOL
=
2
,
TRT_RIGHT_POOL
=
3
};
// implement of CornerPool
class
CornerPoolPluginDynamic
:
public
nvinfer1
::
IPluginV2DynamicExt
{
public:
CornerPoolPluginDynamic
(
const
std
::
string
&
name
,
TRT_CORNER_POOL_TYPE
poolType
);
CornerPoolPluginDynamic
(
const
std
::
string
name
,
const
void
*
data
,
size_t
length
);
CornerPoolPluginDynamic
()
=
delete
;
~
CornerPoolPluginDynamic
();
// IPluginV2DynamicExt Methods
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
// IPluginV2Ext Methods
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
// IPluginV2 Methods
const
char
*
getPluginType
()
const
override
;
const
char
*
getPluginVersion
()
const
override
;
int
getNbOutputs
()
const
override
;
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
void
destroy
()
override
;
void
setPluginNamespace
(
const
char
*
pluginNamespace
)
override
;
const
char
*
getPluginNamespace
()
const
override
;
protected:
const
std
::
string
mLayerName
;
std
::
string
mNamespace
;
TRT_CORNER_POOL_TYPE
mPoolType
;
protected:
// To prevent compiler warnings.
using
nvinfer1
::
IPluginV2DynamicExt
::
canBroadcastInputAcrossBatch
;
using
nvinfer1
::
IPluginV2DynamicExt
::
configurePlugin
;
using
nvinfer1
::
IPluginV2DynamicExt
::
enqueue
;
using
nvinfer1
::
IPluginV2DynamicExt
::
getOutputDimensions
;
using
nvinfer1
::
IPluginV2DynamicExt
::
getWorkspaceSize
;
using
nvinfer1
::
IPluginV2DynamicExt
::
isOutputBroadcastAcrossBatch
;
using
nvinfer1
::
IPluginV2DynamicExt
::
supportsFormat
;
};
// CornerPool creator
class
CornerPoolPluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
CornerPoolPluginDynamicCreator
();
const
char
*
getPluginName
()
const
override
;
const
char
*
getPluginVersion
()
const
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serialData
,
size_t
serialLength
)
override
;
void
setPluginNamespace
(
const
char
*
pluginNamespace
)
override
;
const
char
*
getPluginNamespace
()
const
override
;
protected:
nvinfer1
::
PluginFieldCollection
mFC
;
std
::
vector
<
nvinfer1
::
PluginField
>
mPluginAttributes
;
std
::
string
mNamespace
;
};
#endif TRT_CORNER_POOL_HPP // TRT_CORNER_POOL_HPP
tests/test_ops/test_tensorrt.py
View file @
07d681ac
...
...
@@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode):
if
os
.
path
.
exists
(
trt_file
):
os
.
remove
(
trt_file
)
assert
torch
.
allclose
(
pytorch_results
,
trt_results
)
@
pytest
.
mark
.
parametrize
(
'mode'
,
[
'top'
,
'bottom'
,
'left'
,
'right'
])
def
test_corner_pool
(
mode
):
try
:
from
mmcv.ops
import
CornerPool
except
(
ImportError
,
ModuleNotFoundError
):
pytest
.
skip
(
'test requires compilation'
)
opset
=
11
# register custom op `mmcv::MMCVCornerPool`
from
mmcv.onnx.symbolic
import
register_extra_symbolics
register_extra_symbolics
(
opset
)
# trt config
fp16_mode
=
False
max_workspace_size
=
1
<<
30
inputs
=
[
# (n, c, h, w)
torch
.
rand
((
2
,
3
,
5
,
5
)),
torch
.
rand
((
1
,
2
,
4
,
6
)),
torch
.
rand
((
2
,
1
,
3
,
2
)),
]
class
CornerPoolWrapper
(
CornerPool
):
def
__init__
(
self
,
mode
):
super
(
CornerPoolWrapper
,
self
).
__init__
(
mode
)
def
forward
(
self
,
x
):
# no use `torch.cummax`, instead `corner_pool` is used
# for various torch version
return
self
.
corner_pool
.
apply
(
x
)
wrapped_model
=
CornerPoolWrapper
(
mode
).
cuda
()
for
input
in
inputs
:
input
=
input
.
cuda
()
with
torch
.
no_grad
():
torch
.
onnx
.
export
(
wrapped_model
,
(
input
,
),
onnx_file
,
export_params
=
True
,
keep_initializers_as_inputs
=
True
,
input_names
=
[
'input'
],
output_names
=
[
'output'
],
opset_version
=
opset
)
onnx_model
=
onnx
.
load
(
onnx_file
)
# create trt engine and wraper
opt_shape_dict
=
{
'input'
:
[
list
(
input
.
shape
),
list
(
input
.
shape
),
list
(
input
.
shape
)],
}
trt_engine
=
onnx2trt
(
onnx_model
,
opt_shape_dict
,
fp16_mode
=
fp16_mode
,
max_workspace_size
=
max_workspace_size
)
save_trt_engine
(
trt_engine
,
trt_file
)
trt_model
=
TRTWrapper
(
trt_file
,
[
'input'
],
[
'output'
])
with
torch
.
no_grad
():
trt_outputs
=
trt_model
({
'input'
:
input
})
trt_pool_feat
=
trt_outputs
[
'output'
]
# compute pytorch_output
with
torch
.
no_grad
():
pytorch_pool_feat
=
wrapped_model
(
input
)
# allclose
if
os
.
path
.
exists
(
onnx_file
):
os
.
remove
(
onnx_file
)
if
os
.
path
.
exists
(
trt_file
):
os
.
remove
(
trt_file
)
assert
torch
.
allclose
(
pytorch_pool_feat
,
trt_pool_feat
,
atol
=
1e-5
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment