Unverified Commit be2616a0 authored by tangyanf, committed by GitHub

add ort nms op (#803)

* add ort nms op

* fix lint check

* fix lint check

* update code

* fix lint check

* update code

* update code

* update code

* update code

* update code

* update code
parent 6c57b88f
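
For context, the intended end-to-end usage of this op (a sketch, not part of the patch; the file name 'nms.onnx' and the input arrays are illustrative): export a model that calls mmcv.ops.nms with opset 11, then point an onnxruntime session at the compiled custom-op library:

# Sketch only: 'nms.onnx' and the inputs are illustrative.
import numpy as np
import onnxruntime as rt
from mmcv.ops import get_onnxruntime_op_path

session_options = rt.SessionOptions()
# make the mmcv::NonMaxSuppression kernel visible to onnxruntime
session_options.register_custom_ops_library(get_onnxruntime_op_path())

sess = rt.InferenceSession('nms.onnx', session_options)
boxes = np.random.rand(4, 4).astype(np.float32)
scores = np.random.rand(4).astype(np.float32)
outputs = sess.run(None, {'boxes': boxes, 'scores': scores})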
#include "nms.h"
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <vector>
#include "../ort_mmcv_utils.h"
NmsKernel::NmsKernel(OrtApi api, const OrtKernelInfo *info)
    : api_(api), ort_(api_), info_(info) {
  iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
  offset_ = ort_.KernelInfoGetAttribute<int64_t>(info, "offset");

  // create allocator
  allocator_ = Ort::AllocatorWithDefaultOptions();
}
void NmsKernel::Compute(OrtKernelContext *context) {
  const float iou_threshold = iou_threshold_;
  const int64_t offset = offset_;

  const OrtValue *boxes = ort_.KernelContext_GetInput(context, 0);
  const float *boxes_data = ort_.GetTensorData<float>(boxes);
  const OrtValue *scores = ort_.KernelContext_GetInput(context, 1);
  const float *scores_data = ort_.GetTensorData<float>(scores);

  OrtTensorDimensions boxes_dim(ort_, boxes);
  OrtTensorDimensions scores_dim(ort_, scores);

  int64_t nboxes = boxes_dim[0];
  assert(boxes_dim[1] == 4);
  // allocate tmp memory
  float *tmp_boxes = (float *)allocator_.Alloc(sizeof(float) * nboxes * 4);
  float *sc = (float *)allocator_.Alloc(sizeof(float) * nboxes);
  float *areas = (float *)allocator_.Alloc(sizeof(float) * nboxes);
  bool *select = (bool *)allocator_.Alloc(sizeof(bool) * nboxes);
  for (int64_t i = 0; i < nboxes; i++) {
    select[i] = true;
  }

  memcpy(tmp_boxes, boxes_data, sizeof(float) * nboxes * 4);
  memcpy(sc, scores_data, sizeof(float) * nboxes);
  // sort scores in descending order, keeping the permutation in `order`
  std::vector<float> tmp_sc(sc, sc + nboxes);
  std::vector<int64_t> order(tmp_sc.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&tmp_sc](int64_t id1, int64_t id2) {
    return tmp_sc[id1] > tmp_sc[id2];
  });
  // area = (x2 - x1 + offset) * (y2 - y1 + offset)
  for (int64_t i = 0; i < nboxes; i++) {
    areas[i] = (tmp_boxes[i * 4 + 2] - tmp_boxes[i * 4 + 0] + offset) *
               (tmp_boxes[i * 4 + 3] - tmp_boxes[i * 4 + 1] + offset);
  }
  // greedy suppression: walk boxes in score order and drop every
  // lower-scored box whose IoU with a kept box reaches the threshold
  for (int64_t _i = 0; _i < nboxes; _i++) {
    if (!select[_i]) continue;
    auto i = order[_i];
    auto ix1 = tmp_boxes[i * 4 + 0];
    auto iy1 = tmp_boxes[i * 4 + 1];
    auto ix2 = tmp_boxes[i * 4 + 2];
    auto iy2 = tmp_boxes[i * 4 + 3];
    auto iarea = areas[i];

    for (int64_t _j = _i + 1; _j < nboxes; _j++) {
      if (!select[_j]) continue;
      auto j = order[_j];
      auto xx1 = std::max(ix1, tmp_boxes[j * 4 + 0]);
      auto yy1 = std::max(iy1, tmp_boxes[j * 4 + 1]);
      auto xx2 = std::min(ix2, tmp_boxes[j * 4 + 2]);
      auto yy2 = std::min(iy2, tmp_boxes[j * 4 + 3]);
      auto w = std::max(0.f, xx2 - xx1 + offset);
      auto h = std::max(0.f, yy2 - yy1 + offset);
      auto inter = w * h;
      auto ovr = inter / (iarea + areas[j] - inter);
      if (ovr >= iou_threshold) select[_j] = false;
    }
  }
  // collect surviving indices, highest score first
  std::vector<int64_t> res_order;
  for (int64_t i = 0; i < nboxes; i++) {
    if (select[i]) {
      res_order.push_back(order[i]);
    }
  }

  std::vector<int64_t> inds_dims({static_cast<int64_t>(res_order.size())});

  OrtValue *res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(),
                                               inds_dims.size());
  int64_t *res_data = ort_.GetTensorMutableData<int64_t>(res);

  memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size());

  // release the tmp memory allocated above
  allocator_.Free(tmp_boxes);
  allocator_.Free(sc);
  allocator_.Free(areas);
  allocator_.Free(select);
}
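
The kernel above is a straightforward greedy NMS. For sanity-checking its output on small inputs, here is a NumPy sketch of the same loop (the function name and everything in it are illustrative, not part of the patch):

import numpy as np

def nms_reference(boxes, scores, iou_threshold, offset=0):
    """Greedy NMS mirroring the kernel: boxes is (N, 4) [x1, y1, x2, y2]."""
    order = scores.argsort()[::-1]
    areas = ((boxes[:, 2] - boxes[:, 0] + offset) *
             (boxes[:, 3] - boxes[:, 1] + offset))
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # intersection of the current best box with all remaining boxes
        xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        inter = (np.maximum(0.0, xx2 - xx1 + offset) *
                 np.maximum(0.0, yy2 - yy1 + offset))
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # drop everything at or above the threshold, as the kernel does
        order = order[1:][iou < iou_threshold]
    return np.array(keep, dtype=np.int64)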
#include "onnxruntime_register.h"
#include "nms.h"
#include "ort_mmcv_utils.h"
#include "roi_align.h"
#include "soft_nms.h"
const char *c_MMCVOpDomain = "mmcv";
SoftNmsOp c_SoftNmsOp;
NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
@@ -21,6 +23,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
    return status;
  }

  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_NmsOp)) {
    return status;
  }

  if (auto status =
          ortApi->CustomOpDomain_Add(domain, &c_MMCVRoiAlignCustomOp)) {
    return status;
......
#ifndef ONNXRUNTIME_NMS_H
#define ONNXRUNTIME_NMS_H

#include <onnxruntime_cxx_api.h>

struct NmsKernel {
  NmsKernel(OrtApi api, const OrtKernelInfo *info);

  void Compute(OrtKernelContext *context);

 protected:
  OrtApi api_;
  Ort::CustomOpApi ort_;
  const OrtKernelInfo *info_;
  Ort::AllocatorWithDefaultOptions allocator_;

  float iou_threshold_;
  int64_t offset_;
};

struct NmsOp : Ort::CustomOpBase<NmsOp, NmsKernel> {
  void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
    return new NmsKernel(api, info);
  };

  const char *GetName() const { return "NonMaxSuppression"; };

  size_t GetInputTypeCount() const { return 2; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  size_t GetOutputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  // force cpu
  const char *GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }
};
#endif  // ONNXRUNTIME_NMS_H
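
The schema above (two float inputs, one int64 output, attributes iou_threshold and offset, pinned to the CPU provider) maps onto an ONNX node in the mmcv domain. A hand-built graph sketch using onnx.helper, with illustrative tensor names (the test below exports from PyTorch instead):

import onnx
from onnx import TensorProto, helper

node = helper.make_node(
    'NonMaxSuppression',
    inputs=['boxes', 'scores'],
    outputs=['inds'],
    domain='mmcv',  # matches c_MMCVOpDomain in the register file
    iou_threshold=0.3,
    offset=0)

graph = helper.make_graph(
    [node], 'nms_graph',
    [helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [None, 4]),
     helper.make_tensor_value_info('scores', TensorProto.FLOAT, [None])],
    [helper.make_tensor_value_info('inds', TensorProto.INT64, [None])])

model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('', 11),
                          helper.make_opsetid('mmcv', 1)])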
@@ -21,22 +21,34 @@ class NMSop(torch.autograd.Function):
    @staticmethod
    def symbolic(g, bboxes, scores, iou_threshold, offset):
        from ..onnx import is_custom_op_loaded
        has_custom_op = is_custom_op_loaded()
        if has_custom_op:
            return g.op(
                'mmcv::NonMaxSuppression',
                bboxes,
                scores,
                iou_threshold_f=float(iou_threshold),
                offset_i=int(offset))
        else:
            from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
            boxes = unsqueeze(g, bboxes, 0)
            scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
            max_output_per_class = g.op(
                'Constant',
                value_t=torch.tensor([sys.maxsize], dtype=torch.long))
            iou_threshold = g.op(
                'Constant',
                value_t=torch.tensor([iou_threshold], dtype=torch.float))
            nms_out = g.op('NonMaxSuppression', boxes, scores,
                           max_output_per_class, iou_threshold)
            return squeeze(
                g,
                select(
                    g, nms_out, 1,
                    g.op(
                        'Constant',
                        value_t=torch.tensor([2], dtype=torch.long))), 1)
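
With this symbolic in place, torch.onnx.export emits a single mmcv::NonMaxSuppression node when is_custom_op_loaded() is true, instead of the squeeze/select workaround in the else branch. A minimal export sketch (the wrapper module and file name are illustrative):

import torch
from mmcv.ops import nms

class NMSWrapper(torch.nn.Module):
    def forward(self, boxes, scores):
        dets, inds = nms(boxes, scores, iou_threshold=0.3, offset=0)
        return inds

xy1 = torch.rand(8, 2) * 10
boxes = torch.cat([xy1, xy1 + torch.rand(8, 2) * 5], dim=1)
scores = torch.rand(8)
torch.onnx.export(
    NMSWrapper(), (boxes, scores), 'nms.onnx',
    input_names=['boxes', 'scores'], opset_version=11)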
class SoftNMSop(torch.autograd.Function):
@@ -142,13 +154,7 @@ def nms(boxes, scores, iou_threshold, offset=0):
        select = ext_module.nms(*indata_list, **indata_dict)
        inds = order.masked_select(select)
    else:
        inds = NMSop.apply(boxes, scores, iou_threshold, offset)
    dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
    if is_numpy:
        dets = dets.cpu().numpy()
......
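
The coordinate-shifting workaround removed above is no longer needed because offset is now forwarded to the custom op as the offset_i attribute. The convention, in a small worked example (values are illustrative):

# offset=1 treats coordinates as inclusive pixel indices, offset=0 as continuous
x1, y1, x2, y2 = 0.0, 0.0, 9.0, 9.0
area_continuous = (x2 - x1 + 0) * (y2 - y1 + 0)  # 81.0
area_pixel = (x2 - x1 + 1) * (y2 - y1 + 1)       # 100.0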
@@ -24,7 +24,7 @@ class WrapFunction(nn.Module):
def test_nms():
    from mmcv.ops import get_onnxruntime_op_path, nms
    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
                        dtype=np.float32)
@@ -46,12 +46,19 @@ def test_nms():
        opset_version=11)
    onnx_model = onnx.load(onnx_file)

    ort_custom_op_path = get_onnxruntime_op_path()
    if not os.path.exists(ort_custom_op_path):
        pytest.skip('nms for onnxruntime is not compiled.')
    session_options = rt.SessionOptions()
    session_options.register_custom_ops_library(ort_custom_op_path)

    # get onnx output
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert (len(net_feed_input) == 2)
    sess = rt.InferenceSession(onnx_file, session_options)
    onnx_dets, _ = sess.run(None, {
        'scores': scores.detach().numpy(),
        'boxes': boxes.detach().numpy()
......