Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdeploy
Commits
546b4279
Commit
546b4279
authored
Jun 25, 2025
by
limm
Browse files
add csrc and mmdeploy module
parent
502f4fb9
Pipeline
#2810
canceled with stages
Changes
447
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2107 additions
and
0 deletions
+2107
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
...ulti_level_roi_align/trt_multi_level_roi_align_kernel.hpp
+13
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
...l_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
+228
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp
...l_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp
+79
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu
...ted_roi_align/trt_multi_level_rotated_roi_align_kernel.cu
+164
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
...ed_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
+13
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
...s/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
+173
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp
...s/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp
+70
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
...orrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
+64
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
...rrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
+257
-0
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
...rrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
+15
-0
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
...mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
+241
-0
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
...mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
+72
-0
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu
...oy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu
+107
-0
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp
...y/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp
+15
-0
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp
...ed_dot_product_attention/scaled_dot_product_attention.cpp
+183
-0
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp
...ed_dot_product_attention/scaled_dot_product_attention.hpp
+73
-0
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu
..._product_attention/scaled_dot_product_attention_kernel.cu
+103
-0
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp
...product_attention/scaled_dot_product_attention_kernel.hpp
+17
-0
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp
...mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp
+156
-0
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp
...mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp
+64
-0
No files found.
Too many changes to show.
To preserve performance only
447 of 447+
files are displayed.
Plain diff
Email patch
csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
#define TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
#include <cuda_runtime.h>

// Launch the multi-level RoIAlign CUDA kernel (implementation lives in the
// matching .cu file; this header only declares the host entry point).
//
// Parameters (names per the implementation's convention):
//   output           - device buffer receiving the pooled features.
//   rois             - device buffer of input rois; num_rois entries.
//   feats            - array of num_feats device pointers, one per feature level.
//   n, c             - batch size and channel count of the feature maps.
//   h, w, strides    - per-level feature heights, widths and strides
//                      (host arrays of length num_feats).
//   aligned_height/aligned_width - pooled output spatial size.
//   pool_mode        - pooling mode selector (semantics defined by the kernel;
//                      not visible from this header).
//   sample_num       - sampling points per bin.
//   roi_scale_factor - optional roi rescale factor.
//   finest_scale     - scale threshold used for level assignment.
//   aligned          - whether to apply the half-pixel alignment offset.
//   stream           - CUDA stream the kernel is enqueued on.
template <typename T>
void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
                           int num_feats, int n, int c, int *h, int *w, float *strides,
                           int aligned_height, int aligned_width, int pool_mode, int sample_num,
                           float roi_scale_factor, int finest_scale, bool aligned,
                           cudaStream_t stream);
#endif  // TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "trt_multi_level_rotated_roi_align.hpp"

#include <assert.h>

#include <chrono>

#include "trt_multi_level_rotated_roi_align_kernel.hpp"
#include "trt_plugin_helper.hpp"
#include "trt_serialize.hpp"

namespace mmdeploy {
namespace {
// Plugin identity; TensorRT matches serialized engines to creators by these strings.
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"};
}  // namespace

// Construct from explicit attributes (network-definition / createPlugin path).
TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(
    const std::string &name, int alignedHeight, int alignedWidth, int clockwise, int sampleNum,
    const std::vector<float> &featmapStrides, float roiScaleFactor, int finestScale, bool aligned)
    : TRTPluginBase(name),
      mAlignedHeight(alignedHeight),
      mAlignedWidth(alignedWidth),
      mClockwise(clockwise),
      mSampleNum(sampleNum),
      mFeatmapStrides(featmapStrides),
      mRoiScaleFactor(roiScaleFactor),
      mFinestScale(finestScale),
      mAligned(aligned) {}

// Construct from a serialized engine blob.
// NOTE: the read order here must stay in sync with serialize() below.
TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name,
                                                           const void *data, size_t length)
    : TRTPluginBase(name) {
  deserialize_value(&data, &length, &mAlignedHeight);
  deserialize_value(&data, &length, &mAlignedWidth);
  deserialize_value(&data, &length, &mClockwise);
  deserialize_value(&data, &length, &mSampleNum);
  deserialize_value(&data, &length, &mRoiScaleFactor);
  deserialize_value(&data, &length, &mFinestScale);
  deserialize_value(&data, &length, &mAligned);
  deserialize_value(&data, &length, &mFeatmapStrides);
}

// Deep-copy the plugin with all attributes preserved.
nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT {
  TRTMultiLevelRotatedRoiAlign *plugin = new TRTMultiLevelRotatedRoiAlign(
      mLayerName, mAlignedHeight, mAlignedWidth, mClockwise, mSampleNum, mFeatmapStrides,
      mRoiScaleFactor, mFinestScale, mAligned);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Output shape: [num_rois, channels, mAlignedHeight, mAlignedWidth].
nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  // warning, nbInputs should equal to mFeatmapStrides.size() + 1
  nvinfer1::DimsExprs ret;
  ret.nbDims = 4;
  ret.d[0] = inputs[0].d[0];  // number of rois (input 0 is the roi tensor)
  ret.d[1] = inputs[1].d[1];  // channels of the first feature map
  ret.d[2] = exprBuilder.constant(mAlignedHeight);
  ret.d[3] = exprBuilder.constant(mAlignedWidth);
  return ret;
}

// Every input and output must be a linear-format FP32 tensor.
bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
         ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
}

void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs,
                                                   int nbInputs,
                                                   const nvinfer1::DynamicPluginTensorDesc *outputs,
                                                   int nbOutputs) TRT_NOEXCEPT {
  // Validate input arguments
  ASSERT(nbOutputs == 1);
  ASSERT(nbInputs >= 1);
  // Keep exactly one stride per feature-map input (input 0 is the roi tensor).
  mFeatmapStrides =
      std::vector<float>(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1);
}

size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                                                      int nbInputs,
                                                      const nvinfer1::PluginTensorDesc *outputs,
                                                      int nbOutputs) const TRT_NOEXCEPT {
  // The kernel works entirely in the input/output buffers; no scratch needed.
  return 0;
}

// Gather per-level shapes/strides and launch the CUDA kernel on `stream`.
int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                          const nvinfer1::PluginTensorDesc *outputDesc,
                                          const void *const *inputs, void *const *outputs,
                                          void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  int num_rois = inputDesc[0].dims.d[0];
  int batch_size = inputDesc[1].dims.d[0];
  int channels = inputDesc[1].dims.d[1];

  // Fixed-size host staging arrays for per-level metadata.
  const int kMaxFeatMap = 10;
  int heights[kMaxFeatMap];
  int widths[kMaxFeatMap];
  float strides[kMaxFeatMap];

  int num_feats = mFeatmapStrides.size();
  // Guard the fixed-size buffers above: more than kMaxFeatMap feature levels
  // would silently overflow the stack arrays.
  ASSERT(num_feats <= kMaxFeatMap);
  for (int i = 0; i < num_feats; ++i) {
    heights[i] = inputDesc[i + 1].dims.d[2];
    widths[i] = inputDesc[i + 1].dims.d[3];
    strides[i] = mFeatmapStrides[i];
  }

  const void *rois = inputs[0];
  const void *const *feats = inputs + 1;

  multi_level_rotated_roi_align<float>((float *)outputs[0], (const float *)rois, num_rois, feats,
                                       num_feats, batch_size, channels, &heights[0], &widths[0],
                                       &strides[0], mAlignedHeight, mAlignedWidth, mClockwise,
                                       mSampleNum, mRoiScaleFactor, mFinestScale, mAligned,
                                       stream);

  return 0;
}

nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  // Only FP32 is supported (see supportsFormatCombination).
  return nvinfer1::DataType::kFLOAT;
}

// IPluginV2 Methods
const char *TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT {
  return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) +
         serialized_size(mAlignedWidth) + serialized_size(mClockwise) +
         serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) +
         serialized_size(mFinestScale) + serialized_size(mAligned);
}

// NOTE: the write order here must stay in sync with the deserializing constructor.
void TRTMultiLevelRotatedRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT {
  serialize_value(&buffer, mAlignedHeight);
  serialize_value(&buffer, mAlignedWidth);
  serialize_value(&buffer, mClockwise);
  serialize_value(&buffer, mSampleNum);
  serialize_value(&buffer, mRoiScaleFactor);
  serialize_value(&buffer, mFinestScale);
  serialize_value(&buffer, mAligned);
  serialize_value(&buffer, mFeatmapStrides);
}

// Declare the attribute names the ONNX exporter is expected to provide.
TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() {
  mPluginAttributes = std::vector<nvinfer1::PluginField>(
      {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"),
       nvinfer1::PluginField("clockwise"), nvinfer1::PluginField("sampling_ratio"),
       nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"),
       nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")});
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

// Parse plugin fields (with defaults for any field not supplied) and build the plugin.
nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  int alignedHeight = 7;
  int alignedWidth = 7;
  int clockwise = 0;
  int sampleNum = 2;
  std::vector<float> featmapStrides;
  float roiScaleFactor = -1;
  int finestScale = 56;
  bool aligned = false;

  for (int i = 0; i < fc->nbFields; i++) {
    if (fc->fields[i].data == nullptr) {
      continue;
    }
    std::string field_name(fc->fields[i].name);

    if (field_name.compare("output_height") == 0) {
      alignedHeight = static_cast<const int *>(fc->fields[i].data)[0];
    } else if (field_name.compare("output_width") == 0) {
      alignedWidth = static_cast<const int *>(fc->fields[i].data)[0];
    } else if (field_name.compare("clockwise") == 0) {
      clockwise = static_cast<const int *>(fc->fields[i].data)[0];
    } else if (field_name.compare("sampling_ratio") == 0) {
      sampleNum = static_cast<const int *>(fc->fields[i].data)[0];
    } else if (field_name.compare("roi_scale_factor") == 0) {
      roiScaleFactor = static_cast<const float *>(fc->fields[i].data)[0];
    } else if (field_name.compare("finest_scale") == 0) {
      finestScale = static_cast<const int *>(fc->fields[i].data)[0];
    } else if (field_name.compare("featmap_strides") == 0) {
      // Field length is interpreted as the number of float strides.
      int data_size = (fc->fields[i].length);
      const float *data_start = static_cast<const float *>(fc->fields[i].data);
      featmapStrides = std::vector<float>(data_start, data_start + data_size);
    } else if (field_name.compare("aligned") == 0) {
      int aligned_int = static_cast<const int *>(fc->fields[i].data)[0];
      aligned = aligned_int != 0;
    }
  }

  // featmap_strides is the only mandatory attribute.
  ASSERT(featmapStrides.size() != 0);

  TRTMultiLevelRotatedRoiAlign *plugin =
      new TRTMultiLevelRotatedRoiAlign(name, alignedHeight, alignedWidth, clockwise, sampleNum,
                                       featmapStrides, roiScaleFactor, finestScale, aligned);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator);
}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP
#define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"
namespace mmdeploy {

// TensorRT dynamic-shape plugin wrapping the multi-level rotated RoIAlign
// CUDA kernel. Inputs: one roi tensor followed by one tensor per feature
// level; output: pooled features of size [num_rois, C, alignedHeight,
// alignedWidth] (see the .cpp implementation).
class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase {
 public:
  // Attribute-based constructor used by the plugin creator.
  TRTMultiLevelRotatedRoiAlign(const std::string &name, int alignedHeight, int alignedWidth,
                               int clockwise, int sampleNum,
                               const std::vector<float> &featmapStrides,
                               float roiScaleFactor = -1, int finestScale = 56,
                               bool aligned = false);

  // Deserializing constructor; `data`/`length` come from a serialized engine.
  TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, size_t length);

  // A plugin must always be built with attributes or from serialized data.
  TRTMultiLevelRotatedRoiAlign() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;

 private:
  int mAlignedHeight;                 // pooled output height
  int mAlignedWidth;                  // pooled output width
  int mClockwise;                     // nonzero -> negate roi angle in the kernel
  int mSampleNum;                     // sampling points per output bin
  std::vector<float> mFeatmapStrides; // stride of each feature level
  float mRoiScaleFactor;              // optional roi rescale; <= 0 disables it
  int mFinestScale;                   // scale threshold for level assignment
  bool mAligned;                      // apply -0.5 half-pixel offset when true
};

// Creator that registers the plugin with TensorRT and parses its attributes.
class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase {
 public:
  TRTMultiLevelRotatedRoiAlignCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include <float.h>
#include <stdio.h>

#include <algorithm>
#include <cmath>

#include "common_cuda_helper.hpp"
#include "trt_multi_level_rotated_roi_align_kernel.hpp"
#include "trt_plugin_helper.hpp"

// Maximum number of feature-map levels the FeatData struct can carry.
const int kMAX_FEATMAP_SIZE = 10;

// Per-level feature metadata passed to the kernel by value.
struct FeatData {
  const void *data[kMAX_FEATMAP_SIZE];    // device pointer of each level
  int batch_size;
  int channels;
  int h[kMAX_FEATMAP_SIZE];               // height of each level
  int w[kMAX_FEATMAP_SIZE];               // width of each level
  float spatial_scale[kMAX_FEATMAP_SIZE]; // 1 / stride of each level
  int num_featmap;
};

// Compute one pooled output value for a rotated roi bin by averaging
// bilinearly-interpolated samples on a rotated grid.
template <typename scalar_t, bool aligned>
__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data,
                                     const int roi_batch_ind, scalar_t roi_center_w,
                                     scalar_t roi_center_h, scalar_t roi_width,
                                     scalar_t roi_height, scalar_t theta,
                                     const scalar_t spatial_scale, const int pw, const int ph,
                                     const int c, const int sample_num, const int channels,
                                     const int height, const int width, const int pooled_height,
                                     const int pooled_width) {
  // Force malformed ROIs to be 1x1
  roi_width = max(roi_width, (scalar_t)1.);
  roi_height = max(roi_height, (scalar_t)1.);

  const scalar_t bin_size_h = roi_height / scalar_t(pooled_height);
  const scalar_t bin_size_w = roi_width / scalar_t(pooled_width);

  // Start of the (batch, channel) feature plane this thread samples from.
  const scalar_t *offset_bottom_data =
      bottom_data + (roi_batch_ind * channels + c) * height * width;

  // Sampling grid size per bin: explicit sample_num, or adaptive (ceil of bin size).
  const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height);
  const int roi_bin_grid_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);

  // Roi-local coordinates start at the roi center, hence -size/2.
  const scalar_t roi_start_h = -roi_height / scalar_t(2.0);
  const scalar_t roi_start_w = -roi_width / scalar_t(2.0);

  const scalar_t cosscalar_theta = cos(theta);
  const scalar_t sinscalar_theta = sin(theta);

  // We do average (integral) pooling inside a bin
  const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

  scalar_t output_val = 0.;
  for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
    const scalar_t yy = roi_start_h + ph * bin_size_h +
                        static_cast<scalar_t>(iy + .5f) * bin_size_h /
                            static_cast<scalar_t>(roi_bin_grid_h);  // e.g., 0.5, 1.5
    for (int ix = 0; ix < roi_bin_grid_w; ix++) {
      const scalar_t xx = roi_start_w + pw * bin_size_w +
                          static_cast<scalar_t>(ix + .5f) * bin_size_w /
                              static_cast<scalar_t>(roi_bin_grid_w);

      // Rotate by theta (counterclockwise) around the center and translate
      scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
      scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

      scalar_t val = bilinear_interpolate<scalar_t>(offset_bottom_data, height, width, y, x);
      output_val += val;
    }
  }
  return output_val / count;
}

// One thread per output element (roi, channel, ph, pw). Each thread picks the
// feature level for its roi from the roi's scale, then pools one bin.
// Rois are laid out as 6 floats: [batch_ind, cx, cy, w, h, theta].
template <typename scalar_t, bool aligned>
__global__ void rotated_roi_extractor_kernel(scalar_t *__restrict__ output,
                                             const scalar_t *__restrict__ bottom_rois,
                                             FeatData feat_data, const int clockwise,
                                             const int sample_num, const float roi_scale_factor,
                                             const int finest_scale, const int pooled_height,
                                             const int pooled_width, int nThreads) {
  CUDA_1D_KERNEL_LOOP(index, nThreads) {
    const int channels = feat_data.channels;
    // Decompose the flat index into (n, c, ph, pw).
    int tmp_index = index;
    const int pw = tmp_index % pooled_width;
    tmp_index /= pooled_width;
    const int ph = tmp_index % pooled_height;
    tmp_index /= pooled_height;
    const int c = tmp_index % channels;
    const int n = tmp_index / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;

    scalar_t roi_offset_x0 = offset_bottom_rois[1];
    scalar_t roi_offset_y0 = offset_bottom_rois[2];
    scalar_t roi_offset_width = offset_bottom_rois[3];
    scalar_t roi_offset_height = offset_bottom_rois[4];
    scalar_t theta = offset_bottom_rois[5];

    // Pick the feature level whose scale best matches sqrt(roi area),
    // clamped to [0, num_featmap - 1].
    const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height);

    const int target_lvls =
        min(feat_data.num_featmap - 1,
            max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6)))));

    // Optional roi rescale (applied AFTER level selection).
    if (roi_scale_factor > 0.) {
      roi_offset_width = roi_offset_width * roi_scale_factor;
      roi_offset_height = roi_offset_height * roi_scale_factor;
    }

    const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
    const int height = feat_data.h[target_lvls];
    const int width = feat_data.w[target_lvls];
    const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls];

    const int roi_batch_ind = offset_bottom_rois[0];
    // Half-pixel offset when aligned=true.
    const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0;
    const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset);
    const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset);
    const scalar_t roi_width = roi_offset_width * spatial_scale;
    const scalar_t roi_height = roi_offset_height * spatial_scale;

    // Clockwise angles are handled by negating theta.
    theta = clockwise > 0 ? -theta : theta;

    const scalar_t output_val = roi_align_single<scalar_t, aligned>(
        bottom_data, roi_batch_ind, roi_center_w, roi_center_h, roi_width, roi_height, theta,
        spatial_scale, pw, ph, c, sample_num, channels, height, width, pooled_height,
        pooled_width);

    output[index] = output_val;
  }
}

// Host launcher: pack per-level metadata into FeatData and dispatch the kernel
// with `aligned` resolved at compile time.
// NOTE(review): num_feats is assumed <= kMAX_FEATMAP_SIZE; callers must
// guarantee this or FeatData's arrays overflow.
template <typename T>
void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois,
                                   const void *const *feats, int num_feats, int n, int c, int *h,
                                   int *w, float *strides, int aligned_height, int aligned_width,
                                   int clockwise, int sample_num, float roi_scale_factor,
                                   int finest_scale, bool aligned, cudaStream_t stream) {
  FeatData feat_data;
  feat_data.batch_size = n;
  feat_data.channels = c;
  feat_data.num_featmap = num_feats;
  for (int i = 0; i < num_feats; ++i) {
    feat_data.data[i] = feats[i];
    feat_data.h[i] = h[i];
    feat_data.w[i] = w[i];
    feat_data.spatial_scale[i] = 1. / float(strides[i]);
  }
  // One thread per output element.
  int nThreads = num_rois * c * aligned_height * aligned_width;
  if (aligned) {
    rotated_roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
        aligned_height, aligned_width, nThreads);
  } else {
    rotated_roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0,
                                             stream>>>(
        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
        aligned_height, aligned_width, nThreads);
  }
}

// Explicit instantiation for the only element type used by the plugin (FP32).
template void multi_level_rotated_roi_align<float>(float *output, const float *rois, int num_rois,
                                                   const void *const *feats, int num_feats, int n,
                                                   int c, int *h, int *w, float *strides,
                                                   int aligned_height, int aligned_width,
                                                   int clockwise, int sample_num,
                                                   float roi_scale_factor, int finest_scale,
                                                   bool aligned, cudaStream_t stream);
csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
#define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
#include <cuda_runtime.h>

// Launch the multi-level rotated RoIAlign CUDA kernel (implemented in the
// matching .cu file). `output` receives [num_rois, c, aligned_height,
// aligned_width] pooled values; `rois` holds num_rois rotated rois; `feats`
// is an array of num_feats device pointers with per-level sizes in h/w and
// strides. `clockwise` selects the rotation direction, `sample_num` the
// samples per bin, `roi_scale_factor` an optional roi rescale (<= 0 disables),
// `finest_scale` the level-assignment threshold, and `aligned` the half-pixel
// offset mode. Work is enqueued on `stream`.
template <typename T>
void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois,
                                   const void *const *feats, int num_feats, int n, int c, int *h,
                                   int *w, float *strides, int aligned_height, int aligned_width,
                                   int clockwise, int sample_num, float roi_scale_factor,
                                   int finest_scale, bool aligned, cudaStream_t stream);
#endif  // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_ms_deform_attn.hpp"

#include <assert.h>

#include <chrono>

#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
// Plugin identity; TensorRT matches serialized engines to creators by these strings.
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"};
}  // namespace

// This plugin is stateless: no attributes to store or serialize.
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
    const std::string &name)
    : TRTPluginBase(name) {}

// Deserializing constructor; nothing to read since the plugin has no state.
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
    const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {}

MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT {
  MultiScaleDeformableAttnPluginDynamic *plugin =
      new MultiScaleDeformableAttnPluginDynamic(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Output shape: [batch, num_query, num_heads * channels], derived from
// input 0 (value: [batch, spatial, heads, channels]) and input 3
// (sampling locations, whose dim 1 is num_query).
nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[3].d[1];
  ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]);
  return ret;
}

// Linear format only. Inputs 1 and 2 (spatial shapes / level start index) must
// be INT32; all other tensors must share input 0's type and be FP32 or FP16.
bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) {
    if ((pos == 1) || (pos == 2)) {
      return (ioDesc[pos].type == nvinfer1::DataType::kINT32);
    } else {
      return ((ioDesc[pos].type == ioDesc[0].type) &&
              ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) ||
               (ioDesc[pos].type == nvinfer1::DataType::kHALF)));
    }
  } else {
    return false;
  }
}

// Nothing to configure: all sizes are read from the descriptors at enqueue time.
void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {}

size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

// Read problem sizes from the input descriptors and dispatch the CUDA forward
// implementation for the active dtype (FP32 or FP16).
int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                                   const nvinfer1::PluginTensorDesc *outputDesc,
                                                   const void *const *inputs,
                                                   void *const *outputs, void *workSpace,
                                                   cudaStream_t stream) TRT_NOEXCEPT {
  int32_t const batch = inputDesc[0].dims.d[0];
  int32_t spatial_size = inputDesc[0].dims.d[1];
  int32_t num_heads = inputDesc[0].dims.d[2];
  int32_t channels = inputDesc[0].dims.d[3];
  int32_t num_levels = inputDesc[1].dims.d[0];
  int32_t num_query = inputDesc[3].dims.d[1];
  int32_t num_point = inputDesc[3].dims.d[4];
  int32_t rc = 0;
  if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) {
    float const *value = static_cast<float const *>(inputs[0]);
    int32_t const *spatialShapes = static_cast<int32_t const *>(inputs[1]);
    int32_t const *levelStartIndex = static_cast<int32_t const *>(inputs[2]);
    float const *samplingLoc = static_cast<float const *>(inputs[3]);
    float const *attnWeight = static_cast<float const *>(inputs[4]);
    float *output = static_cast<float *>(outputs[0]);

    rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc,
                                     attnWeight, output, batch, spatial_size, num_heads, channels,
                                     num_levels, num_query, num_point, stream);
  } else if (inputDesc[0].type == nvinfer1::DataType::kHALF) {
    const __half *value = static_cast<const __half *>(inputs[0]);
    int32_t const *spatialShapes = static_cast<int32_t const *>(inputs[1]);
    int32_t const *levelStartIndex = static_cast<int32_t const *>(inputs[2]);
    const __half *samplingLoc = static_cast<const __half *>(inputs[3]);
    const __half *attnWeight = static_cast<const __half *>(inputs[4]);
    __half *output = static_cast<__half *>(outputs[0]);

    rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc,
                                     attnWeight, output, batch, spatial_size, num_heads, channels,
                                     num_levels, num_query, num_point, stream);
  }
  // Other dtypes are rejected by supportsFormatCombination, so rc stays 0 here.
  return rc;
}

// Output keeps the dtype of input 0 (FP32 or FP16).
nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

// No attributes -> nothing to serialize.
size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return 0;
}

void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {}

// The kernel needs no cudnn/cublas handles; these hooks are intentionally empty.
void MultiScaleDeformableAttnPluginDynamic::attachToContext(
    cudnnContext *cudnnContext, cublasContext *cublasContext,
    nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {}

void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

// No plugin fields: the op is fully described by its input tensors.
MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
  return PLUGIN_NAME;
}

const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
  return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator);
}  // namespace mmdeploy
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MS_DEFORM_ATTN_HPP
#define TRT_MS_DEFORM_ATTN_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace
mmdeploy
{
// TensorRT dynamic-shape plugin implementing the multi-scale deformable
// attention sampling operator (forward only).
class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase
{
public:
  // Build-time constructor; `name` identifies this layer instance.
  MultiScaleDeformableAttnPluginDynamic(const std::string& name);
  // Deserialization constructor; rebuilds the plugin from engine bytes.
  MultiScaleDeformableAttnPluginDynamic(const std::string name, const void* data, size_t length);
  // NOTE(review): a default constructor is declared here while sibling
  // plugins (e.g. TRTRoIAlign) delete theirs — confirm this is intentional.
  MultiScaleDeformableAttnPluginDynamic();
  ~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
                                          int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs)
      TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const
      TRT_NOEXCEPT override;
  // Runs the forward pass on `stream`; returns 0 on success.
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
              void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char* getPluginType() const TRT_NOEXCEPT override;
  const char* getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;
};
// Factory registered with TensorRT that creates / deserializes
// MultiScaleDeformableAttnPluginDynamic instances.
class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase
{
public:
  MultiScaleDeformableAttnPluginDynamicCreator();
  const char* getPluginName() const TRT_NOEXCEPT override;
  const char* getPluginVersion() const TRT_NOEXCEPT override;
  // Build-time creation from a (possibly empty) field collection.
  nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override;
  // Engine-load creation from serialized bytes.
  nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}
// namespace mmdeploy
#endif // TRT_MS_DEFORM_ATTN_HPP
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#include <assert.h>
#include <cuda_fp16.h>
#include "common_cuda_helper.hpp"
#include "trt_ms_deform_attn_kernel.cuh"
#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_plugin_helper.hpp"
// Launch the im2col-style sampling kernel for multi-scale deformable
// attention. One GPU thread is mapped to each output element, i.e.
// batchSize * numQuery * numHeads * channels threads in total.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue,
                               int32_t const* dataSpatialShapes,
                               int32_t const* dataLevelStartIndex,
                               scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
                               int32_t const batchSize, int32_t const spatialSize,
                               int32_t const numHeads, int32_t const channels,
                               int32_t const numLevels, int32_t const numQuery,
                               int32_t const numPoint, scalar_t* dataCol)
{
  // numKernels is passed to the kernel as the loop bound; numActualKernels
  // sizes the grid. They are computed identically here.
  int32_t const numKernels = batchSize * numQuery * numHeads * channels;
  int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
          numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc,
          dataAttnWeight, batchSize, spatialSize, numHeads, channels, numLevels, numQuery,
          numPoint, dataCol);
}
// Host-side forward entry point for multi-scale deformable attention.
// Computes per-batch element strides and dispatches the im2col launcher.
// Always returns 0.
template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
                                    cudaStream_t stream)
{
  // Number of elements contributed by a single batch item in each tensor.
  auto perValueSize = mSpatialSize * mNumHeads * mChannels;
  auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;  // (x, y) pairs
  auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
  auto perOutputSize = mNumQuery * mNumHeads * mChannels;
  // The im2col step equals the full batch, so the loop below executes
  // exactly once (batch / mIm2colStep == 1); the chunked structure is kept
  // to allow smaller steps.
  int32_t mIm2colStep = batch;
  for (int32_t n = 0; n < batch / mIm2colStep; ++n)
  {
    auto columns = output + n * mIm2colStep * perOutputSize;
    ms_deformable_im2col_cuda<scalar_t>(
        stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
        samplingLoc + n * mIm2colStep * perSampleLocSize,
        attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep, mSpatialSize, mNumHeads,
        mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
  }
  return 0;
}
// Explicit instantiations for fp32 and fp16 so the plugin translation unit
// can link against this .cu file without seeing the template definition.
template int32_t ms_deform_attn_cuda_forward<float>(
    const float* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
    const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);

template int32_t ms_deform_attn_cuda_forward<__half>(
    const __half* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
    const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
0 → 100644
View file @
546b4279
// modify from:
// https://github.com/NVIDIA/TensorRT/blob/main/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh
#include <cuda_fp16.h>
#include "common_cuda_helper.hpp"
// Bilinear interpolation of `bottom_data` at fractional location (h, w) for
// head `m`, channel `c`. The feature map is laid out as
// [height, width, nheads, channels] (row stride = width * nheads * channels).
// Corners that fall outside the map contribute 0 (zero padding).
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottom_data,
                                                   const int& height, const int& width,
                                                   const int& nheads, const int& channels,
                                                   const scalar_t& h, const scalar_t& w,
                                                   const int& m, const int& c)
{
  // Integer cell containing the sample point and the fractional remainders.
  const int h_low = floorf(h);
  const int w_low = floorf(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;
  // Strides into the [h, w, head, channel] layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;
  // Gather the four corners, substituting 0 for out-of-bounds corners.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }
  // Standard bilinear weights and blend.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
// fp16 specialization of the bilinear sampler. On sm_53+ it uses native
// half-precision intrinsics; on older architectures each arithmetic step
// round-trips through float, which preserves the same rounding behavior the
// intrinsics would give only approximately — results can differ slightly
// between the two paths.
template <>
__device__ __half ms_deform_attn_im2col_bilinear<__half>(
    const __half*& bottomData, int32_t const& height, int32_t const& width,
    int32_t const& nHeads, int32_t const& channels, const __half& h, const __half& w,
    int32_t const& m, int32_t const& c)
{
  // floor() of the half-precision coordinates.
  int32_t const hLow = __half2int_rd(h);
  int32_t const wLow = __half2int_rd(w);
  int32_t const hHigh = hLow + 1;
  int32_t const wHigh = wLow + 1;
  const __half kZERO = __int2half_rz(0);
  const __half one = __int2half_rz(1);
#if __CUDA_ARCH__ >= 530
  const __half lh = __hsub(h, __int2half_rd(hLow));
  const __half lw = __hsub(w, __int2half_rd(wLow));
  const __half hh = __hsub(one, lh), hw = __hsub(one, lw);
#else
  // Pre-sm_53 fallback: compute in float, convert back to half.
  const __half lh = __float2half(__half2float(h) - hLow);
  const __half lw = __float2half(__half2float(w) - wLow);
  const __half hh = __float2half(__half2float(one) - __half2float(lh));
  const __half hw = __float2half(__half2float(one) - __half2float(lw));
#endif
  // Strides into the [h, w, head, channel] layout.
  int32_t const wStride = nHeads * channels;
  int32_t const hStride = width * wStride;
  int32_t const hLowPtrOffset = hLow * hStride;
  int32_t const hHighPtrOffset = hLowPtrOffset + hStride;
  int32_t const wLowPtrOffset = wLow * wStride;
  int32_t const wHighPtrOffset = wLowPtrOffset + wStride;
  int32_t const basePtr = m * channels + c;
  // Gather corners with zero padding for out-of-bounds positions.
  __half v1 = kZERO;
  if (hLow >= 0 && wLow >= 0)
  {
    int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr;
    v1 = bottomData[ptr1];
  }
  __half v2 = kZERO;
  if (hLow >= 0 && wHigh <= width - 1)
  {
    int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr;
    v2 = bottomData[ptr2];
  }
  __half v3 = kZERO;
  if (hHigh <= height - 1 && wLow >= 0)
  {
    int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr;
    v3 = bottomData[ptr3];
  }
  __half v4 = kZERO;
  if (hHigh <= height - 1 && wHigh <= width - 1)
  {
    int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr;
    v4 = bottomData[ptr4];
  }
#if __CUDA_ARCH__ >= 530
  // Weighted blend of the four corners, accumulated pairwise in half.
  __half w1 = __hmul(__hmul(hh, hw), v1);
  __half w2 = __hmul(__hmul(hh, lw), v2);
  __half w3 = __hmul(__hmul(lh, hw), v3);
  __half w4 = __hmul(__hmul(lh, lw), v4);
  w1 = __hadd(w1, w2);
  w3 = __hadd(w3, w4);
  const __half val = __hadd(w1, w3);
#else
  __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1));
  __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2));
  __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3));
  __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4));
  w1 = __float2half(__half2float(w1) + __half2float(w2));
  w3 = __float2half(__half2float(w3) + __half2float(w4));
  const __half val = __float2half(__half2float(w1) + __half2float(w3));
#endif
  return val;
}
#if 1
// One thread per output element: decode (batch, query, head, channel) from
// the flat index, then accumulate the bilinear samples of every level/point,
// each scaled by its attention weight, into a single output scalar.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(
    int32_t const n, scalar_t const* dataValue, int32_t const* dataSpatialShapes,
    int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc,
    scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
    int32_t const numHeads, int32_t const channels, int32_t const numLevels,
    int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol)
{
  CUDA_1D_KERNEL_LOOP(index, n)
  {
    // Decompose flat index: fastest-varying is channel, then head, then
    // query, then batch.
    int32_t _temp = index;
    int32_t const cCol = _temp % channels;
    _temp /= channels;
    // samplingIndex = batch * numQuery * numHeads + query * numHeads + head;
    // it addresses the (loc, weight) tables, which have no channel axis.
    int32_t const samplingIndex = _temp;
    int32_t const mCol = _temp % numHeads;
    _temp /= numHeads;
    _temp /= numQuery;
    int32_t const bCol = _temp;
    scalar_t* dataColPtr = dataCol + index;
    int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
    int32_t dataLocWPtr = dataWeightPtr << 1;  // locations store (x, y) pairs
    int32_t const qidStride = numHeads * channels;
    int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;
    scalar_t col = 0;
    for (int32_t lCol = 0; lCol < numLevels; ++lCol)
    {
      int32_t const levelStartId = dataLevelStartIndex[lCol];
      int32_t const spatialHPtr = lCol << 1;  // shapes are (H, W) pairs
      int32_t const spatialH = dataSpatialShapes[spatialHPtr];
      int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
      scalar_t const* dataValuePtr =
          dataValue + (dataValuePtrInitOffset + levelStartId * qidStride);
      for (int32_t pCol = 0; pCol < numPoint; ++pCol)
      {
        scalar_t const locW = dataSamplingLoc[dataLocWPtr];
        scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1];
        scalar_t const weight = dataAttnWeight[dataWeightPtr];
        // Map normalized [0, 1] coordinates to pixel space; the -0.5 shift
        // aligns sample points with pixel centers.
        scalar_t const hIm = locH * spatialH - 0.5;
        scalar_t const wIm = locW * spatialW - 0.5;
        // Skip points sampled entirely outside the feature map.
        if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW)
        {
          col += ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads,
                                                channels, hIm, wIm, mCol, cCol) *
                 weight;
        }
        dataWeightPtr += 1;
        dataLocWPtr += 2;
      }
    }
    *dataColPtr = col;
  }
}
// fp16 specialization of the im2col kernel. Mirrors the generic version but
// performs the coordinate math and accumulation with half-precision
// intrinsics on sm_53+, or via float round-trips on older architectures.
template <>
__global__ void ms_deformable_im2col_gpu_kernel<__half>(
    int32_t const n, const __half* dataValue, int32_t const* dataSpatialShapes,
    int32_t const* dataLevelStartIndex, const __half* dataSamplingLoc,
    const __half* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
    int32_t const numHeads, int32_t const channels, int32_t const numLevels,
    int32_t const numQuery, int32_t const numPoint, __half* dataCol)
{
  CUDA_1D_KERNEL_LOOP(index, n)
  {
    // Decompose flat index (channel fastest, then head, query, batch) —
    // identical to the generic kernel.
    int32_t _temp = index;
    int32_t const cCol = _temp % channels;
    _temp /= channels;
    int32_t const samplingIndex = _temp;
    int32_t const mCol = _temp % numHeads;
    _temp /= numHeads;
    _temp /= numQuery;
    int32_t const bCol = _temp;
    __half* dataColPtr = dataCol + index;
    int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
    int32_t dataLocWPtr = dataWeightPtr << 1;  // (x, y) pairs
    int32_t const qidStride = numHeads * channels;
    int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;
    // Half-precision constants used by the coordinate transform below.
    const __half kZERO_POINT_FIVE = __float2half(0.5f);
    const __half kMINUS_ONE = __float2half(-1.0f);
    const __half kZERO = __int2half_rz(0);
    __half tpVal = kZERO;
    __half col = kZERO;
    for (int32_t lCol = 0; lCol < numLevels; ++lCol)
    {
      int32_t const levelStartId = dataLevelStartIndex[lCol];
      int32_t const spatialHPtr = lCol << 1;
      int32_t const spatialH = dataSpatialShapes[spatialHPtr];
      int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
      const __half spatialHHalf = __int2half_rd(spatialH);
      const __half spatialWHalf = __int2half_rd(spatialW);
      const __half* dataValuePtr =
          dataValue + (dataValuePtrInitOffset + levelStartId * qidStride);
      for (int32_t pCol = 0; pCol < numPoint; ++pCol)
      {
        const __half locW = dataSamplingLoc[dataLocWPtr];
        const __half locH = dataSamplingLoc[dataLocWPtr + 1];
        const __half weight = dataAttnWeight[dataWeightPtr];
#if __CUDA_ARCH__ >= 530
        // Normalized -> pixel coordinates, pixel-center aligned.
        const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE);
        const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE);
        // Accumulate only samples inside the feature map.
        if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) &&
            __hlt(wIm, spatialWHalf))
        {
          tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads,
                                                 channels, hIm, wIm, mCol, cCol);
          col = __hadd(col, __hmul(tpVal, weight));
        }
#else
        // Pre-sm_53 fallback: same logic with float round-trips.
        const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) -
                                        __half2float(kZERO_POINT_FIVE));
        const __half wIm = __float2half(__half2float(locW) * __half2float(spatialWHalf) -
                                        __half2float(kZERO_POINT_FIVE));
        if ((__half2float(hIm) > __half2float(kMINUS_ONE)) &&
            (__half2float(wIm) > __half2float(kMINUS_ONE)) &&
            (__half2float(hIm) < __half2float(spatialHHalf)) &&
            (__half2float(wIm) < __half2float(spatialWHalf)))
        {
          tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads,
                                                 channels, hIm, wIm, mCol, cCol);
          col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight)));
        }
#endif
        dataWeightPtr += 1;
        dataLocWPtr += 2;
      }
    }
    *dataColPtr = col;
  }
}
#endif
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_MS_DEFORM_ATTN_KERNEL_HPP
#define TRT_MS_DEFORM_ATTN_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
// Forward pass of multi-scale deformable attention, executed on `stream`.
// `value` holds the flattened multi-level feature maps; `spatialShapes` /
// `levelStartIndex` describe each level's (H, W) and its offset into
// `value`; `samplingLoc` / `attnWeight` are the per-query sampling
// locations and attention weights. Result is written to `output`.
// Returns 0 on success. Instantiated for float and __half in the .cu file.
template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
                                    cudaStream_t stream);
#endif
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "trt_roi_align.hpp"
#include <chrono>
#include <iostream>
#include "common_cuda_helper.hpp"
#include "trt_plugin_helper.hpp"
#include "trt_roi_align_kernel.hpp"
#include "trt_serialize.hpp"
namespace
mmdeploy
{
namespace
{
static
const
char
*
PLUGIN_VERSION
{
"1"
};
static
const
char
*
PLUGIN_NAME
{
"MMCVRoiAlign"
};
}
// namespace
// Build-time constructor: captures the RoIAlign configuration.
// outWidth/outHeight: pooled output size; spatialScale: feature-map scale
// relative to the input image; sampleRatio: samples per bin (<=0 means
// adaptive); poolMode: 0 = max, 1 = avg; aligned: pixel-center alignment.
TRTRoIAlign::TRTRoIAlign(const std::string& name, int outWidth, int outHeight, float spatialScale,
                         int sampleRatio, int poolMode, bool aligned)
    : TRTPluginBase(name),
      mOutWidth(outWidth),
      mOutHeight(outHeight),
      mSpatialScale(spatialScale),
      mSampleRatio(sampleRatio),
      mPoolMode(poolMode),
      mAligned(aligned)
{}
// Deserialization constructor: restores the configuration from engine
// bytes. Field order must exactly match serialize().
TRTRoIAlign::TRTRoIAlign(const std::string name, const void* data, size_t length)
    : TRTPluginBase(name)
{
  deserialize_value(&data, &length, &mOutWidth);
  deserialize_value(&data, &length, &mOutHeight);
  deserialize_value(&data, &length, &mSpatialScale);
  deserialize_value(&data, &length, &mSampleRatio);
  deserialize_value(&data, &length, &mPoolMode);
  deserialize_value(&data, &length, &mAligned);
}
// Produce an independent copy of this plugin carrying the same
// configuration. Ownership of the copy passes to TensorRT.
nvinfer1::IPluginV2DynamicExt* TRTRoIAlign::clone() const TRT_NOEXCEPT
{
  auto* copy = new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale, mSampleRatio,
                               mPoolMode, mAligned);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}
// Output shape is [num_rois, channels, mOutHeight, mOutWidth]:
// batch comes from the RoI input (inputs[1]), channels from the feature
// input (inputs[0]), and the spatial size is the configured pooled size.
nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(int outputIndex,
                                                     const nvinfer1::DimsExprs* inputs,
                                                     int nbInputs,
                                                     nvinfer1::IExprBuilder& exprBuilder)
    TRT_NOEXCEPT
{
  nvinfer1::DimsExprs ret;
  ret.nbDims = 4;
  ret.d[0] = inputs[1].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = exprBuilder.constant(mOutHeight);
  ret.d[3] = exprBuilder.constant(mOutWidth);
  return ret;
}
// Accept only fp32 tensors in linear (row-major) layout for every input
// and output slot.
bool TRTRoIAlign::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc,
                                            int nbInputs, int nbOutputs) TRT_NOEXCEPT
{
  const auto& desc = ioDesc[pos];
  const bool is_fp32 = desc.type == nvinfer1::DataType::kFLOAT;
  const bool is_linear = desc.format == nvinfer1::TensorFormat::kLINEAR;
  return is_fp32 && is_linear;
}
// No shape-dependent setup is needed; all configuration is fixed at
// construction time, so this hook is intentionally empty.
void TRTRoIAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
                                  const nvinfer1::DynamicPluginTensorDesc* outputs,
                                  int nbOutputs) TRT_NOEXCEPT
{}
// Report the scratch space enqueue() needs. Only max pooling
// (mPoolMode == 0) uses workspace: two buffers (argmax_y / argmax_x), each
// the same element count and dtype as the output tensor. Average pooling
// needs none.
//
// Fixes over the original: removed unreachable `break` statements after
// `return` and the duplicated default/fall-through `return 0` paths, and
// the element-count product now accumulates in size_t instead of int to
// avoid overflow for large outputs.
size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
                                     const nvinfer1::PluginTensorDesc* outputs,
                                     int nbOutputs) const TRT_NOEXCEPT
{
  if (mPoolMode != 0)
  {
    // avg pooling (mode 1) and any unknown mode require no workspace.
    return 0;
  }
  const auto& dims = outputs[0].dims;
  size_t output_size = static_cast<size_t>(dims.d[0]) * dims.d[1] * dims.d[2] * dims.d[3];
  size_t word_size = mmdeploy::getElementSize(outputs[0].type);
  // Two output-sized buffers: argmax_y and argmax_x.
  return output_size * word_size * 2;
}
// Execute RoIAlign on `stream`. inputs[0] is the feature map [N, C, H, W];
// inputs[1] is the RoI tensor (rows of [batch_idx, x1, y1, x2, y2] as read
// by the kernel). For max pooling the workspace is carved into two
// output-sized argmax buffers. Only kFLOAT is dispatched; other dtypes are
// silently skipped (supportsFormatCombination restricts I/O to kFLOAT).
int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                         const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
                         void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT
{
  int channels = inputDesc[0].dims.d[1];
  int height = inputDesc[0].dims.d[2];
  int width = inputDesc[0].dims.d[3];
  // Total number of output elements; drives both grid size and workspace
  // partitioning.
  int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2] *
                    outputDesc[0].dims.d[3];
  int word_size = mmdeploy::getElementSize(outputDesc[0].type);
  const void* feat = inputs[0];
  const void* rois = inputs[1];
  void* output = outputs[0];
  void* argmax_y = nullptr;
  void* argmax_x = nullptr;
  switch (mPoolMode)
  {
    case 0:
      // max: split the workspace into [argmax_y | argmax_x].
      argmax_y = workSpace;
      argmax_x = (char*)argmax_y + output_size * word_size;
      break;
    case 1:
      // avg: no scratch buffers needed; argmax pointers stay null.
      break;
  }
  switch (outputDesc[0].type)
  {
    case nvinfer1::DataType::kFLOAT:
      TRTRoIAlignForwardCUDAKernelLauncher<float>(
          (const float*)feat, (const float*)rois, (float*)output, (float*)argmax_y,
          (float*)argmax_x, output_size, channels, height, width, mOutHeight, mOutWidth,
          mSpatialScale, mSampleRatio, mPoolMode, mAligned, stream);
      break;
    default:
      break;
  }
  return 0;
}
// The output tensor has the same dtype as the feature-map input.
nvinfer1::DataType TRTRoIAlign::getOutputDataType(int index,
                                                  const nvinfer1::DataType* inputTypes,
                                                  int nbInputs) const TRT_NOEXCEPT
{
  return inputTypes[0];
}
// IPluginV2 Methods
// Registry key for this plugin type ("MMCVRoiAlign").
const char* TRTRoIAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
// Version string paired with the plugin name for registry lookup.
const char* TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
// RoIAlign produces exactly one output tensor.
int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
// Total byte count serialize() will write: the six configuration fields,
// in the same order serialize() emits them.
size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT
{
  size_t total = serialized_size(mOutWidth);
  total += serialized_size(mOutHeight);
  total += serialized_size(mSpatialScale);
  total += serialized_size(mSampleRatio);
  total += serialized_size(mPoolMode);
  total += serialized_size(mAligned);
  return total;
}
// Write the configuration into `buffer`. Field order must exactly match
// the deserializing constructor and getSerializationSize().
void TRTRoIAlign::serialize(void* buffer) const TRT_NOEXCEPT
{
  serialize_value(&buffer, mOutWidth);
  serialize_value(&buffer, mOutHeight);
  serialize_value(&buffer, mSpatialScale);
  serialize_value(&buffer, mSampleRatio);
  serialize_value(&buffer, mPoolMode);
  serialize_value(&buffer, mAligned);
}
// Creator constructor: advertise the creation-time attributes that
// createPlugin() understands, then publish them through mFC.
TRTRoIAlignCreator::TRTRoIAlignCreator()
{
  static const char* const kFieldNames[] = {"output_height", "output_width", "spatial_scale",
                                            "sampling_ratio", "mode", "aligned"};
  for (const char* field_name : kFieldNames)
  {
    mPluginAttributes.emplace_back(nvinfer1::PluginField(field_name));
  }
  // mFC must mirror mPluginAttributes: TensorRT reads it when callers query
  // the supported plugin fields.
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}
// Registry key for the plugin type this creator builds.
const char* TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
// Version string paired with the plugin name for registry lookup.
const char* TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
// Build-time factory: parse the field collection into RoIAlign parameters,
// validate them, and construct the plugin. Unrecognized fields are ignored;
// fields with null data are skipped. Defaults: 7x7 output, scale 1.0,
// adaptive sampling (0), aligned=true; `mode` is mandatory in effect, since
// poolMode starts at -1 and the final ASSERT rejects it if never set.
nvinfer1::IPluginV2* TRTRoIAlignCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT
{
  int outWidth = 7;
  int outHeight = 7;
  float spatialScale = 1.0;
  int sampleRatio = 0;
  int poolMode = -1;
  bool aligned = true;
  for (int i = 0; i < fc->nbFields; i++)
  {
    if (fc->fields[i].data == nullptr)
    {
      continue;
    }
    std::string field_name(fc->fields[i].name);
    if (field_name.compare("output_height") == 0)
    {
      outHeight = static_cast<const int*>(fc->fields[i].data)[0];
    }
    if (field_name.compare("output_width") == 0)
    {
      outWidth = static_cast<const int*>(fc->fields[i].data)[0];
    }
    if (field_name.compare("spatial_scale") == 0)
    {
      spatialScale = static_cast<const float*>(fc->fields[i].data)[0];
    }
    if (field_name.compare("sampling_ratio") == 0)
    {
      sampleRatio = static_cast<const int*>(fc->fields[i].data)[0];
    }
    if (field_name.compare("mode") == 0)
    {
      int data_size = fc->fields[i].length;
      ASSERT(data_size > 0);
      const char* data_start = static_cast<const char*>(fc->fields[i].data);
      // NOTE(review): this assumes the field data is NUL-terminated; the
      // declared length is checked but not used to bound the string —
      // confirm callers always pass a terminated string.
      std::string pool_mode(data_start);
      if (pool_mode == "avg")
      {
        poolMode = 1;
      }
      else if (pool_mode == "max")
      {
        poolMode = 0;
      }
      else
      {
        std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl;
      }
      ASSERT(poolMode >= 0);
    }
    if (field_name.compare("aligned") == 0)
    {
      int aligned_int = static_cast<const int*>(fc->fields[i].data)[0];
      aligned = aligned_int != 0;
    }
  }
  // Reject incomplete or nonsensical configurations before constructing.
  ASSERT(outHeight > 0);
  ASSERT(outWidth > 0);
  ASSERT(spatialScale > 0.);
  ASSERT(poolMode >= 0);
  TRTRoIAlign* plugin = new TRTRoIAlign(name, outWidth, outHeight, spatialScale, sampleRatio,
                                        poolMode, aligned);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
// Engine-load factory: rebuild a TRTRoIAlign from serialized bytes.
// Ownership of the returned pointer passes to TensorRT.
nvinfer1::IPluginV2* TRTRoIAlignCreator::deserializePlugin(const char* name,
                                                           const void* serialData,
                                                           size_t serialLength) TRT_NOEXCEPT
{
  auto* restored = new TRTRoIAlign(name, serialData, serialLength);
  restored->setPluginNamespace(getPluginNamespace());
  return restored;
}
REGISTER_TENSORRT_PLUGIN
(
TRTRoIAlignCreator
);
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_ROI_ALIGN_HPP
#define TRT_ROI_ALIGN_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace
mmdeploy
{
// TensorRT dynamic-shape plugin implementing RoIAlign (max or average
// pooling over bilinearly-sampled RoI bins), fp32 only.
class TRTRoIAlign : public TRTPluginBase
{
public:
  // Build-time constructor; see the member comments below for parameters.
  TRTRoIAlign(const std::string& name, int outWidth, int outHeight, float spatialScale,
              int sampleRatio, int poolMode, bool aligned);
  // Deserialization constructor; field order must match serialize().
  TRTRoIAlign(const std::string name, const void* data, size_t length);
  // A plugin without configuration is meaningless; force one of the above.
  TRTRoIAlign() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
                                          int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs)
      TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const
      TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
              void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char* getPluginType() const TRT_NOEXCEPT override;
  const char* getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

private:
  int mOutWidth;        // pooled output width
  int mOutHeight;       // pooled output height
  float mSpatialScale;  // feature-map scale relative to the input image
  int mSampleRatio;     // samples per bin; <= 0 means adaptive
  int mPoolMode;        // 1:avg 0:max
  bool mAligned;        // pixel-center alignment (-0.5 offset)
};
// Factory registered with TensorRT that creates / deserializes TRTRoIAlign
// instances from plugin fields or engine bytes.
class TRTRoIAlignCreator : public TRTPluginCreatorBase
{
public:
  TRTRoIAlignCreator();
  const char* getPluginName() const TRT_NOEXCEPT override;
  const char* getPluginVersion() const TRT_NOEXCEPT override;
  // Build-time creation from the advertised field collection.
  nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override;
  // Engine-load creation from serialized bytes.
  nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}
// namespace mmdeploy
#endif // TRT_ROI_ALIGN_HPP
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "common_cuda_helper.hpp"
#include "float.h"
#include "trt_roi_align_kernel.hpp"
/*** Forward ***/
// RoIAlign forward kernel: one thread per output element (n, c, ph, pw).
// Each RoI row is [batch_idx, x1, y1, x2, y2] in input-image coordinates;
// the bin value is either the max (pool_mode 0, with argmax coordinates
// recorded for backward) or the average (pool_mode 1) of a grid of
// bilinearly-interpolated samples.
template <typename T>
__global__ void roi_align_forward_cuda_kernel(
    const int nthreads, const T* input, const T* rois, T* output, T* argmax_y, T* argmax_x,
    const int pooled_height, const int pooled_width, const T spatial_scale,
    const int sampling_ratio,
    const int pool_mode,  // 0 - max pool, 1 - avg pool
    const bool aligned, const int channels, const int height, const int width)
{
  CUDA_1D_KERNEL_LOOP(index, nthreads)
  {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    // Do not using rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    // RoI corners scaled to feature-map coordinates.
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;
    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned)
    {
      // for backward-compatibility only
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
    const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : static_cast<int>(ceilf(roi_height / pooled_height));
    int roi_bin_grid_w = (sampling_ratio > 0)
                             ? sampling_ratio
                             : static_cast<int>(ceilf(roi_width / pooled_width));
    if (pool_mode == 0)
    {
      // We do max pooling inside a bin
      T maxval = -FLT_MAX;
      T maxidx_y = -1.f, maxidx_x = -1.f;
      for (int iy = 0; iy < roi_bin_grid_h; iy++)
      {
        // Sample at the center of each sub-cell of the bin.
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++)
        {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
          T val = bilinear_interpolate(offset_input, height, width, y, x);
          if (val > maxval)
          {
            maxval = val;
            maxidx_y = y;
            maxidx_x = x;
          }
        }
      }
      output[index] = maxval;
      // Record where the max came from (used by the backward pass).
      argmax_y[index] = maxidx_y;
      argmax_x[index] = maxidx_x;
    }
    else if (pool_mode == 1)
    {
      // We do average pooling inside a bin
      const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
      T output_val = 0.;
      for (int iy = 0; iy < roi_bin_grid_h; iy++)
      {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++)
        {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
          T val = bilinear_interpolate(offset_input, height, width, y, x);
          output_val += val;
        }
      }
      output[index] = output_val / count;
    }
  }
}
// Host-side launcher: one thread per output element, standard 1-D grid.
// argmax_y / argmax_x may be null when pool_mode is avg (the kernel only
// writes them on the max-pooling path).
template <typename scalar_t>
void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois,
                                          scalar_t* output, scalar_t* argmax_y,
                                          scalar_t* argmax_x, int output_size, int channels,
                                          int height, int width, int aligned_height,
                                          int aligned_width, scalar_t spatial_scale,
                                          int sampling_ratio, int pool_mode, bool aligned,
                                          cudaStream_t stream)
{
  roi_align_forward_cuda_kernel<scalar_t>
      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
          output_size, input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
          static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode, aligned, channels,
          height, width);
}
// Explicit instantiation for float so the kernel is compiled into this
// translation unit and callable from the plugin's .cpp without including the .cu.
template void TRTRoIAlignForwardCUDAKernelLauncher<float>(
    const float* input, const float* rois, float* output, float* argmax_y, float* argmax_x,
    int output_size, int channels, int height, int width, int aligned_height, int aligned_width,
    float spatial_scale, int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream);
csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef ROI_ALIGN_CUDA_KERNEL_HPP
#define ROI_ALIGN_CUDA_KERNEL_HPP
#include "common_cuda_helper.hpp"

// Launches the RoI Align forward CUDA kernel (defined in the companion .cu file;
// instantiated there for float).
//
// input:    device feature map, indexed as (channels, height, width) per batch item
// rois:     device RoI boxes consumed by the kernel
// output:   device buffer of `output_size` pooled values,
//           aligned_height x aligned_width per RoI/channel
// argmax_y/argmax_x: per-output sample coordinates of the max — written only
//           when pool_mode selects max pooling (pool_mode == 0 in the kernel)
// pool_mode: 0 = max, 1 = average (matches the kernel's switch)
// aligned:  half-pixel alignment convention toggle
// The launch is asynchronous on `stream`.
template <typename scalar_t>
void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois,
                                          scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x,
                                          int output_size, int channels, int height, int width,
                                          int aligned_height, int aligned_width,
                                          scalar_t spatial_scale, int sampling_ratio, int pool_mode,
                                          bool aligned, cudaStream_t stream);
#endif  // ROI_ALIGN_CUDA_KERNEL_HPP
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#include "scaled_dot_product_attention.hpp"
#include <assert.h>
#include <chrono>
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_serialize.hpp"
using
namespace
nvinfer1
;
namespace
mmdeploy
{
namespace
{
static
const
char
*
PLUGIN_VERSION
{
"1"
};
static
const
char
*
PLUGIN_NAME
{
"ScaledDotProductAttentionTRT"
};
}
// namespace
ScaledDotProductAttentionTRT
::
ScaledDotProductAttentionTRT
(
const
std
::
string
&
name
)
:
TRTPluginBase
(
name
),
mask_dim
(
0
)
{}
ScaledDotProductAttentionTRT
::
ScaledDotProductAttentionTRT
(
const
std
::
string
name
,
const
void
*
data
,
size_t
length
)
:
TRTPluginBase
(
name
),
mask_dim
(
0
)
{}
ScaledDotProductAttentionTRT
::~
ScaledDotProductAttentionTRT
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
ScaledDotProductAttentionTRT
::
clone
()
const
TRT_NOEXCEPT
{
ScaledDotProductAttentionTRT
*
plugin
=
new
ScaledDotProductAttentionTRT
(
mLayerName
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
DimsExprs
ScaledDotProductAttentionTRT
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
TRT_NOEXCEPT
{
if
(
outputIndex
==
0
)
return
inputs
[
0
];
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
3
;
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
1
]
=
inputs
[
0
].
d
[
1
];
ret
.
d
[
2
]
=
inputs
[
1
].
d
[
1
];
return
ret
;
}
bool
ScaledDotProductAttentionTRT
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
ioDesc
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
if
(
pos
==
0
)
{
return
(
ioDesc
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
ioDesc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
return
ioDesc
[
pos
].
type
==
ioDesc
[
0
].
type
&&
ioDesc
[
pos
].
format
==
ioDesc
[
0
].
format
;
}
}
// Attach the plugin object to an execution context and grant the plugin the
// access to some context resource.
void
ScaledDotProductAttentionTRT
::
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
{
_cublas_handle
=
cublasContext
;
_cudnn_handle
=
cudnnContext
;
cudnnCreateTensorDescriptor
(
&
_x_desc
);
cudnnCreateTensorDescriptor
(
&
_y_desc
);
cudnnCreateTensorDescriptor
(
&
_mask_desc
);
}
// Detach the plugin object from its execution context.
void
ScaledDotProductAttentionTRT
::
detachFromContext
()
TRT_NOEXCEPT
{
cudnnDestroyTensorDescriptor
(
_y_desc
);
cudnnDestroyTensorDescriptor
(
_x_desc
);
cudnnDestroyTensorDescriptor
(
_mask_desc
);
}
void
ScaledDotProductAttentionTRT
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
TRT_NOEXCEPT
{
if
(
nbInputs
!=
4
)
{
mask_dim
=
0
;
}
else
{
mask_dim
=
in
[
3
].
desc
.
dims
.
nbDims
;
}
}
int
ScaledDotProductAttentionTRT
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workSpace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
if
(
CUDNN_STATUS_SUCCESS
!=
cudnnSetStream
(
_cudnn_handle
,
stream
))
return
1
;
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasSetStream
(
_cublas_handle
,
stream
))
return
1
;
int
B
=
inputDesc
[
0
].
dims
.
d
[
0
];
// batch * heads
int
Nt
=
inputDesc
[
0
].
dims
.
d
[
1
];
int
Ns
=
inputDesc
[
1
].
dims
.
d
[
1
];
int
E
=
inputDesc
[
0
].
dims
.
d
[
2
];
// embeding size
const
void
*
query
=
inputs
[
0
];
const
void
*
key
=
inputs
[
1
];
const
void
*
value
=
inputs
[
2
];
const
void
*
mask
=
nullptr
;
int
mask_dims
[
3
];
mask_dims
[
0
]
=
0
;
if
(
mask_dim
>
0
)
{
mask
=
inputs
[
3
];
// check if mask need broadcast
if
(
mask_dim
==
2
)
{
mask_dims
[
0
]
=
1
;
mask_dims
[
1
]
=
inputDesc
[
3
].
dims
.
d
[
0
];
mask_dims
[
2
]
=
inputDesc
[
3
].
dims
.
d
[
1
];
}
else
{
mask_dims
[
0
]
=
inputDesc
[
3
].
dims
.
d
[
0
];
mask_dims
[
1
]
=
inputDesc
[
3
].
dims
.
d
[
1
];
mask_dims
[
2
]
=
inputDesc
[
3
].
dims
.
d
[
2
];
}
}
void
*
output
=
outputs
[
0
];
void
*
attn
=
outputs
[
1
];
auto
data_type
=
inputDesc
[
0
].
type
;
cudnnDataType_t
cudnn_dtype
{};
convert_trt2cudnn_dtype
(
data_type
,
&
cudnn_dtype
);
switch
(
data_type
)
{
case
nvinfer1
::
DataType
::
kFLOAT
:
dot_product_attention_impl
<
float
>
((
float
*
)
query
,
(
float
*
)
key
,
(
float
*
)
value
,
(
float
*
)
mask
,
(
float
*
)
attn
,
(
float
*
)
output
,
B
,
Nt
,
Ns
,
E
,
&
mask_dims
[
0
],
_x_desc
,
_y_desc
,
_mask_desc
,
cudnn_dtype
,
stream
,
_cublas_handle
,
_cudnn_handle
);
break
;
default:
return
1
;
}
return
0
;
}
nvinfer1
::
DataType
ScaledDotProductAttentionTRT
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
{
return
inputTypes
[
0
];
}
// IPluginV2 Methods
const
char
*
ScaledDotProductAttentionTRT
::
getPluginType
()
const
TRT_NOEXCEPT
{
return
PLUGIN_NAME
;
}
const
char
*
ScaledDotProductAttentionTRT
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
PLUGIN_VERSION
;
}
int
ScaledDotProductAttentionTRT
::
getNbOutputs
()
const
TRT_NOEXCEPT
{
return
2
;
}
size_t
ScaledDotProductAttentionTRT
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
0
;
}
void
ScaledDotProductAttentionTRT
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{}
////////////////////// creator /////////////////////////////
ScaledDotProductAttentionTRTCreator
::
ScaledDotProductAttentionTRTCreator
()
{}
const
char
*
ScaledDotProductAttentionTRTCreator
::
getPluginName
()
const
TRT_NOEXCEPT
{
return
PLUGIN_NAME
;
}
const
char
*
ScaledDotProductAttentionTRTCreator
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
PLUGIN_VERSION
;
}
nvinfer1
::
IPluginV2
*
ScaledDotProductAttentionTRTCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
{
ScaledDotProductAttentionTRT
*
plugin
=
new
ScaledDotProductAttentionTRT
(
name
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
IPluginV2
*
ScaledDotProductAttentionTRTCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serialData
,
size_t
serialLength
)
TRT_NOEXCEPT
{
auto
plugin
=
new
ScaledDotProductAttentionTRT
(
name
,
serialData
,
serialLength
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
REGISTER_TENSORRT_PLUGIN
(
ScaledDotProductAttentionTRTCreator
);
}
// namespace mmdeploy
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
// TensorRT dynamic plugin computing scaled dot-product attention.
// Inputs: query, key, value, and an optional additive mask (4th input).
// Outputs: 0 = attention result, 1 = attention weights (see the .cpp
// getOutputDimensions / getNbOutputs implementations).
class ScaledDotProductAttentionTRT : public TRTPluginBase {
 public:
  ScaledDotProductAttentionTRT(const std::string &name);

  // Deserialization constructor; this plugin serializes no state.
  ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length);

  ScaledDotProductAttentionTRT() = delete;

  ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override;

  // Records the rank of the optional mask input (input 3) into mask_dim.
  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                               const nvinfer1::DynamicPluginTensorDesc *out,
                               int nbOutputs) TRT_NOEXCEPT override;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;

  // Stores the context's cuDNN/cuBLAS handles and creates tensor descriptors.
  void attachToContext(cudnnContext *cudnn, cublasContext *cublas,
                       nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override;
  // Destroys the tensor descriptors created in attachToContext.
  void detachFromContext() TRT_NOEXCEPT override;

 private:
  // Rank of the optional mask input; 0 when no mask is supplied.
  int mask_dim;
  // Handles are owned by the execution context, not by this plugin.
  cublasHandle_t _cublas_handle{};
  cudnnHandle_t _cudnn_handle{};
  // Scratch descriptors reconfigured on every enqueue.
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{};
};

// Factory registered with TensorRT's plugin registry for the plugin above.
class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase {
 public:
  ScaledDotProductAttentionTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>
#include <cmath>
#include <vector>
#include "common_cuda_helper.hpp"
#include "scaled_dot_product_attention_kernel.hpp"
#include "trt_plugin_helper.hpp"
// Type-dispatching wrapper over cublas<t>gemmStridedBatched so the attention
// implementation below can be written once for any scalar type. Only the
// float and __half specializations are defined (see below); other types
// would fail to link by design.
template <typename scalar_t>
cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa,
                                            cublasOperation_t transb, int m, int n, int k,
                                            const scalar_t *alpha, const scalar_t *A, int lda,
                                            long long int strideA, const scalar_t *B, int ldb,
                                            long long int strideB, const scalar_t *beta,
                                            scalar_t *C, int ldc, long long int strideC,
                                            int batchCount);
// float specialization: forwards verbatim to cublasSgemmStridedBatched.
template <>
cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
                                                   cublasOperation_t transb, int m, int n, int k,
                                                   const float *alpha, const float *A, int lda,
                                                   long long int strideA, const float *B, int ldb,
                                                   long long int strideB, const float *beta,
                                                   float *C, int ldc, long long int strideC,
                                                   int batchCount) {
  return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}
// __half specialization: forwards verbatim to cublasHgemmStridedBatched.
template <>
cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle,
                                                    cublasOperation_t transa,
                                                    cublasOperation_t transb, int m, int n, int k,
                                                    const __half *alpha, const __half *A, int lda,
                                                    long long int strideA, const __half *B,
                                                    int ldb, long long int strideB,
                                                    const __half *beta, __half *C, int ldc,
                                                    long long int strideC, int batchCount) {
  return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
                                   strideB, beta, C, ldc, strideC, batchCount);
}
// Computes output = softmax(query @ key^T / sqrt(E) [+ mask]) @ value for a
// batch of B attention heads, using cuBLAS for the two matmuls and cuDNN for
// the mask add and the softmax. `attn` (B x Nt x Ns) receives the attention
// weights and doubles as scratch between the stages.
//
// NOTE(review): cuBLAS is column-major while the buffers appear to be
// row-major (B, N, E); the m/n/k and transpose choices below compute the
// row-major products via the usual "swap operands" trick. Return codes of the
// cuDNN/cuBLAS calls are not checked — presumably acceptable for this plugin,
// but worth confirming.
//
// x_desc / y_desc / mask_desc are caller-owned scratch descriptors that this
// function reconfigures; mask_dims == nullptr or mask_dims[0] == 0 disables
// the mask stage. `stream` must already be bound to both handles by the
// caller (see ScaledDotProductAttentionTRT::enqueue).
template <typename scalar_t>
void dot_product_attention_impl(const scalar_t *query, const scalar_t *key, const scalar_t *value,
                                const scalar_t *mask, scalar_t *attn, scalar_t *output, int B,
                                int Nt, int Ns, int E, const int *mask_dims,
                                cudnnTensorDescriptor_t &x_desc, cudnnTensorDescriptor_t &y_desc,
                                cudnnTensorDescriptor_t &mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle) {
  {
    // Q @ K
    // attn = (query @ key^T) * (1/sqrt(E)); scaling is folded into alpha.
    const int m = Ns;
    const int n = Nt;
    const int k = E;
    const auto alpha = scalar_t(1.0f / sqrt(float(E)));
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k,
                                 Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B);
  }
  if (mask_dims != nullptr && mask_dims[0] != 0) {
    // attn += mask, broadcast by cuDNN when mask_dims[0] == 1 (2-D mask case).
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(1);
    cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0],
                               mask_dims[1], mask_dims[2]);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns);
    cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn);
  }
  {
    // softmax attention
    // In-place softmax over the last (Ns) axis: each of the B*Nt rows is one
    // softmax instance.
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1);
    cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha,
                        x_desc, attn, &beta, y_desc, attn);
  }
  {
    // attn @ v
    // output = attn @ value, same row-major-via-column-major convention.
    const int m = E;
    const int n = Nt;
    const int k = Ns;
    const auto alpha = scalar_t(1);
    const auto beta = scalar_t(0);
    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value,
                                 m, Ns * E, (const scalar_t *)(attn), k, Ns * Nt, &beta, output, m,
                                 Nt * E, B);
  }
}
// Explicit instantiation for float — the only type dispatched from
// ScaledDotProductAttentionTRT::enqueue.
template void dot_product_attention_impl<float>(
    const float *query, const float *key, const float *value, const float *mask, float *attn,
    float *output, int B, int Nt, int Ns, int E, const int *mask_dims,
    cudnnTensorDescriptor_t &x_desc, cudnnTensorDescriptor_t &y_desc,
    cudnnTensorDescriptor_t &mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream,
    cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle);
csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#define TRT_SCALED_DOT_PRODUCT_ATTENTION_KERNEL_HPP
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cudnn.h>

// Computes output = softmax(query @ key^T / sqrt(E) [+ mask]) @ value via
// cuBLAS + cuDNN. Defined in the companion .cu file; instantiated for float.
// attn receives the (B, Nt, Ns) attention weights; mask is optional
// (mask_dims == nullptr or mask_dims[0] == 0 disables it). The descriptor
// references are caller-owned scratch that the implementation reconfigures.
template <typename scalar_t>
void dot_product_attention_impl(const scalar_t *query, const scalar_t *key, const scalar_t *value,
                                const scalar_t *mask, scalar_t *attn, scalar_t *output, int B,
                                int Nt, int Ns, int E, const int *mask_dims,
                                cudnnTensorDescriptor_t &x_desc, cudnnTensorDescriptor_t &y_desc,
                                cudnnTensorDescriptor_t &mask_desc, cudnnDataType_t cudnn_dtype,
                                cudaStream_t stream, cublasHandle_t cublas_handle,
                                cudnnHandle_t cudnn_handle);

#endif
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#include "NvInferVersion.h"
// ScatterND is supported since TensorRT8
#if NV_TENSORRT_MAJOR <= 7
#include <assert.h>
#include <stdio.h>
#include <chrono>
#include "trt_scatternd.hpp"
#include "trt_scatternd_kernel.hpp"
#include "trt_serialize.hpp"
namespace
mmdeploy
{
namespace
{
static
const
char
*
PLUGIN_VERSION
{
"1"
};
static
const
char
*
PLUGIN_NAME
{
"ScatterND"
};
}
// namespace
TRTScatterND
::
TRTScatterND
(
const
std
::
string
&
name
)
:
TRTPluginBase
(
name
)
{}
TRTScatterND
::
TRTScatterND
(
const
std
::
string
name
,
const
void
*
data
,
size_t
length
)
:
TRTPluginBase
(
name
)
{}
nvinfer1
::
IPluginV2DynamicExt
*
TRTScatterND
::
clone
()
const
TRT_NOEXCEPT
{
TRTScatterND
*
plugin
=
new
TRTScatterND
(
mLayerName
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
DimsExprs
TRTScatterND
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
TRT_NOEXCEPT
{
return
inputs
[
0
];
}
bool
TRTScatterND
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
ioDesc
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
if
(
pos
<
nbInputs
)
{
switch
(
pos
)
{
case
0
:
// data
return
(
ioDesc
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
ioDesc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
)
||
(
ioDesc
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
ioDesc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
case
1
:
// indices
return
ioDesc
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
ioDesc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
case
2
:
// updates
return
ioDesc
[
pos
].
type
==
ioDesc
[
0
].
type
&&
ioDesc
[
pos
].
format
==
ioDesc
[
0
].
format
;
default:
return
true
;
}
}
else
{
switch
(
pos
-
nbInputs
)
{
case
0
:
// output
return
ioDesc
[
pos
].
type
==
ioDesc
[
0
].
type
&&
ioDesc
[
pos
].
format
==
ioDesc
[
0
].
format
;
default:
return
true
;
}
}
return
true
;
}
void
TRTScatterND
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
size_t
TRTScatterND
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
{
return
0
;
}
int
TRTScatterND
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workSpace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
const
int
*
dims
=
&
(
inputDesc
[
0
].
dims
.
d
[
0
]);
const
int
*
indices_dims
=
&
(
inputDesc
[
1
].
dims
.
d
[
0
]);
int
nbDims
=
inputDesc
[
0
].
dims
.
nbDims
;
int
indice_nbDims
=
inputDesc
[
1
].
dims
.
nbDims
;
const
void
*
data
=
inputs
[
0
];
const
void
*
indices
=
inputs
[
1
];
const
void
*
update
=
inputs
[
2
];
void
*
output
=
outputs
[
0
];
auto
data_type
=
inputDesc
[
0
].
type
;
switch
(
data_type
)
{
case
nvinfer1
::
DataType
::
kFLOAT
:
TRTONNXScatterNDKernelLauncher
<
float
>
((
float
*
)
data
,
(
int
*
)
indices
,
(
float
*
)
update
,
dims
,
nbDims
,
indices_dims
,
indice_nbDims
,
(
float
*
)
output
,
stream
);
break
;
case
nvinfer1
::
DataType
::
kINT32
:
TRTONNXScatterNDKernelLauncher
<
int
>
((
int
*
)
data
,
(
int
*
)
indices
,
(
int
*
)
update
,
dims
,
nbDims
,
indices_dims
,
indice_nbDims
,
(
int
*
)
output
,
stream
);
break
;
default:
break
;
}
return
0
;
}
nvinfer1
::
DataType
TRTScatterND
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
{
return
inputTypes
[
0
];
}
// IPluginV2 Methods
const
char
*
TRTScatterND
::
getPluginType
()
const
TRT_NOEXCEPT
{
return
PLUGIN_NAME
;
}
const
char
*
TRTScatterND
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
PLUGIN_VERSION
;
}
int
TRTScatterND
::
getNbOutputs
()
const
TRT_NOEXCEPT
{
return
1
;
}
size_t
TRTScatterND
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
0
;
}
void
TRTScatterND
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{}
TRTScatterNDCreator
::
TRTScatterNDCreator
()
{
mPluginAttributes
.
clear
();
mFC
.
nbFields
=
mPluginAttributes
.
size
();
mFC
.
fields
=
mPluginAttributes
.
data
();
}
const
char
*
TRTScatterNDCreator
::
getPluginName
()
const
TRT_NOEXCEPT
{
return
PLUGIN_NAME
;
}
const
char
*
TRTScatterNDCreator
::
getPluginVersion
()
const
TRT_NOEXCEPT
{
return
PLUGIN_VERSION
;
}
nvinfer1
::
IPluginV2
*
TRTScatterNDCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
TRT_NOEXCEPT
{
TRTScatterND
*
plugin
=
new
TRTScatterND
(
name
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
nvinfer1
::
IPluginV2
*
TRTScatterNDCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serialData
,
size_t
serialLength
)
TRT_NOEXCEPT
{
auto
plugin
=
new
TRTScatterND
(
name
,
serialData
,
serialLength
);
plugin
->
setPluginNamespace
(
getPluginNamespace
());
return
plugin
;
}
REGISTER_TENSORRT_PLUGIN
(
TRTScatterNDCreator
);
}
// namespace mmdeploy
#endif
csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp
0 → 100644
View file @
546b4279
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_HPP
#define TRT_SCATTERND_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
// TensorRT dynamic plugin implementing ONNX ScatterND.
// Inputs: data, indices (int32), updates; one output shaped like data.
class TRTScatterND : public TRTPluginBase {
 public:
  TRTScatterND(const std::string &name);

  // Deserialization constructor; this plugin serializes no state.
  TRTScatterND(const std::string name, const void *data, size_t length);

  TRTScatterND() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
};

// Factory registered with TensorRT's plugin registry for TRTScatterND.
class TRTScatterNDCreator : public TRTPluginCreatorBase {
 public:
  TRTScatterNDCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif // TRT_SCATTERND_HPP
Prev
1
…
9
10
11
12
13
14
15
16
17
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment