"tests/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "d217278ccbd1554c228f6b208e1fa3de608b44ae"
Unverified commit bf2c9fa8 authored by SemyonBevzuk, committed by GitHub

[Feature] NMS update (#957)

* Add score_threshold and max_num to NMS

* Fix codestyle

* Fix codestyle

* Fix inds in nms

* Update nms docstring

* Move score_threshold and max_num arguments

* Fix args order in docstring

* Fix lint of C++ file

* Remove torch.onnx.is_in_onnx_export() and add max_num to batched_nms for separate classes.

* Rewrote max_num handling in NMSop.symbolic

* Added processing of max_output_boxes_per_class when exporting to TensorRT

* Added score_threshold and max_num for NMS in test_onnx.py and test_tensorrt.py

* Remove _is_value(max_num)

* Fix CI errors with torch==1.3.1

* Update test_batched_nms in test_nms.py

* Added tests for preprocess_onnx

* Moved 'test_tensorrt_preprocess.py' and 'preprocess', updated 'remove_tmp_file'.

* Update mmcv/tensorrt/__init__.py

* Fix segfault torch==1.3.1 (remove onnx.checker.check_model)

* Returned 'onnx.checker.check_model' with torch version check

* Changed torch version from 1.3.1 to 1.4.0

*...
parent 717d1571
@@ -20,10 +20,10 @@
 | [SoftNMS](onnxruntime_custom_ops.md#softnms) | Y | N | 1.2.3 |
 | [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
 | [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
-| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
-| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
+| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | 1.3.1 |
+| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | 1.3.4 |
 | [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
 | [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |

 ## How to build custom operators for ONNX Runtime
...
@@ -34,6 +34,7 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
 | cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
 | cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
+| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |

 Notes
 - All plugins listed above are developed on TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0
...
 import os
-import sys

 import numpy as np
 import torch
@@ -15,13 +14,27 @@ ext_module = ext_loader.load_ext(

 class NMSop(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, bboxes, scores, iou_threshold, offset):
+    def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+                max_num):
+        is_filtering_by_score = score_threshold > 0
+        if is_filtering_by_score:
+            valid_mask = scores > score_threshold
+            bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+            valid_inds = torch.nonzero(
+                valid_mask, as_tuple=False).squeeze(dim=1)
+
         inds = ext_module.nms(
             bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+
+        if max_num > 0:
+            inds = inds[:max_num]
+        if is_filtering_by_score:
+            inds = valid_inds[inds]
+
         return inds

     @staticmethod
-    def symbolic(g, bboxes, scores, iou_threshold, offset):
+    def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+                 max_num):
         from ..onnx import is_custom_op_loaded
         has_custom_op = is_custom_op_loaded()
         # TensorRT nms plugin is aligned with original nms in ONNXRuntime
@@ -35,16 +48,28 @@ class NMSop(torch.autograd.Function):
                 offset_i=int(offset))
         else:
             from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+            from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
             boxes = unsqueeze(g, bboxes, 0)
             scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
-            max_output_per_class = g.op(
-                'Constant',
-                value_t=torch.tensor([sys.maxsize], dtype=torch.long))
+
+            if max_num > 0:
+                max_num = g.op(
+                    'Constant',
+                    value_t=torch.tensor(max_num, dtype=torch.long))
+            else:
+                dim = g.op('Constant', value_t=torch.tensor(0))
+                max_num = _size_helper(g, bboxes, dim)
+            max_output_per_class = max_num
             iou_threshold = g.op(
                 'Constant',
                 value_t=torch.tensor([iou_threshold], dtype=torch.float))
+            score_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([score_threshold], dtype=torch.float))
             nms_out = g.op('NonMaxSuppression', boxes, scores,
-                           max_output_per_class, iou_threshold)
+                           max_output_per_class, iou_threshold,
+                           score_threshold)
             return squeeze(
                 g,
                 select(
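To see what the new symbolic path produces, here is a minimal export sketch (assuming mmcv-full built with this patch; `NMSModule` is a hypothetical wrapper for tracing, and note that when mmcv's ONNX Runtime custom op library is loaded, a custom mmcv op is emitted instead of the standard one):

```python
import onnx
import torch

from mmcv.ops import nms


class NMSModule(torch.nn.Module):
    """Hypothetical wrapper so the patched nms can be traced by ONNX export."""

    def forward(self, boxes, scores):
        return nms(boxes, scores, iou_threshold=0.5, score_threshold=0.05,
                   max_num=20)


inputs = (torch.rand(100, 4), torch.rand(100))
torch.onnx.export(NMSModule(), inputs, 'nms.onnx', opset_version=11,
                  input_names=['boxes', 'scores'])

graph = onnx.load('nms.onnx').graph
nms_nodes = [n for n in graph.node if n.op_type == 'NonMaxSuppression']
# The fallback path feeds five inputs: boxes, scores,
# max_output_boxes_per_class, iou_threshold and score_threshold.
assert all(len(n.input) == 5 for n in nms_nodes)
```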
@@ -90,7 +115,7 @@ class SoftNMSop(torch.autograd.Function):

 @deprecated_api_warning({'iou_thr': 'iou_threshold'})
-def nms(boxes, scores, iou_threshold, offset=0):
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
     """Dispatch to either CPU or GPU NMS implementations.

     The input can be either torch tensor or numpy array. GPU NMS will be used
@@ -102,6 +127,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
         scores (torch.Tensor or np.ndarray): scores in shape (N, ).
         iou_threshold (float): IoU threshold for NMS.
         offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+        score_threshold (float): score threshold for NMS.
+        max_num (int): maximum number of boxes after NMS.

     Returns:
         tuple: kept dets (boxes and scores) and indices, which is always the \
@@ -141,7 +168,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
         }
         inds = ext_module.nms(*indata_list, **indata_dict)
     else:
-        inds = NMSop.apply(boxes, scores, iou_threshold, offset)
+        inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+                           score_threshold, max_num)
     dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
     if is_numpy:
         dets = dets.cpu().numpy()
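A quick usage sketch of the extended `nms` signature (values are illustrative):

```python
import torch

from mmcv.ops import nms

boxes = torch.tensor([[49.1, 32.4, 51.0, 35.9],
                      [49.3, 32.9, 51.0, 35.3],
                      [35.3, 11.5, 39.9, 14.5]])
scores = torch.tensor([0.9, 0.6, 0.1])

# The third box is dropped by score_threshold before NMS runs; of the two
# heavily overlapping boxes, only the higher-scoring one survives.
dets, inds = nms(boxes, scores, iou_threshold=0.3, score_threshold=0.2,
                 max_num=2)
# dets: (k, 5) kept boxes with their scores; inds indexes the original input.
```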
@@ -285,6 +313,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
         # Some type of nms would reweight the score, such as SoftNMS
         scores = dets[:, 4]
     else:
+        max_num = nms_cfg_.pop('max_num', -1)
         total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
         # Some type of nms would reweight the score, such as SoftNMS
         scores_after_nms = scores.new_zeros(scores.size())
@@ -294,10 +323,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
             total_mask[mask[keep]] = True
             scores_after_nms[mask[keep]] = dets[:, -1]
         keep = total_mask.nonzero(as_tuple=False).view(-1)

         scores, inds = scores_after_nms[keep].sort(descending=True)
         keep = keep[inds]
         boxes = boxes[keep]
+
+        if max_num > 0:
+            keep = keep[:max_num]
+            boxes = boxes[:max_num]
+            scores = scores[:max_num]

     return torch.cat([boxes, scores[:, None]], -1), keep
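And for `batched_nms`, a companion sketch (illustrative shapes): `max_num` now travels in the `nms_cfg` dict and caps the total number of kept boxes.

```python
import torch

from mmcv.ops import batched_nms

boxes = torch.rand(1000, 4) * 100
boxes[:, 2:] = boxes[:, :2] + torch.rand(1000, 2) * 10  # ensure x2 > x1, y2 > y1
scores = torch.rand(1000)
idxs = torch.randint(0, 5, (1000, ))  # class index of each box

nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1,
               max_num=100)
dets, keep = batched_nms(boxes, scores, idxs, nms_cfg)
assert keep.numel() <= 100  # capped regardless of which branch ran
```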
...
 # flake8: noqa
 from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
-from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine, onnx2trt,
-                             save_trt_engine)
-
-# load tensorrt plugin lib
-load_tensorrt_plugin()
-
-__all__ = [
-    'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
-    'TRTWraper', 'is_tensorrt_plugin_loaded'
-]
+from .preprocess import preprocess_onnx
+
+
+def is_tensorrt_available():
+    try:
+        import tensorrt
+        del tensorrt
+        return True
+    except ModuleNotFoundError:
+        return False
+
+
+__all__ = []
+
+if is_tensorrt_available():
+    from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine,
+                                 onnx2trt, save_trt_engine)
+
+    # load tensorrt plugin lib
+    load_tensorrt_plugin()
+
+    __all__.extend([
+        'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
+        'TRTWrapper'
+    ])
+
+__all__.extend(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
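The effect of the guard, as a short sketch (assuming this patch): `mmcv.tensorrt` now imports cleanly on machines without the `tensorrt` package, and the TensorRT-backed names are only exported when it is available.

```python
# Importable even when TensorRT is absent:
from mmcv.tensorrt import is_tensorrt_available, preprocess_onnx

if is_tensorrt_available():
    # Only now is it safe to pull in the TensorRT-backed utilities.
    from mmcv.tensorrt import TRTWrapper, onnx2trt, save_trt_engine
else:
    print('TensorRT not installed; ONNX preprocessing is still usable.')
```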
import numpy as np
import onnx


def preprocess_onnx(onnx_model):
    """Modify the onnx model to match TensorRT plugins in mmcv.

    There are some conflicts between onnx node definitions and TensorRT
    limits. This function performs preprocessing on the onnx model to solve
    them. For example, an onnx `attribute` is loaded in TensorRT on the host
    while an onnx `input` is loaded on the device. Shape inference is
    performed on the host, so any `input` related to shape (such as
    `max_output_boxes_per_class` in NonMaxSuppression) should be transformed
    into an `attribute` before conversion.

    Arguments:
        onnx_model (onnx.ModelProto): Input onnx model.

    Returns:
        onnx.ModelProto: Modified onnx model.
    """
    graph = onnx_model.graph
    nodes = graph.node
    initializers = graph.initializer
    node_dict = {}
    for node in nodes:
        node_outputs = node.output
        for output in node_outputs:
            if len(output) > 0:
                node_dict[output] = node

    init_dict = {_.name: _ for _ in initializers}

    nodes_name_to_remove = set()

    def is_node_without_output(name):
        for node_name, node in node_dict.items():
            if node_name not in nodes_name_to_remove:
                if name in node.input:
                    return False
        return True

    def mark_nodes_to_remove(name):
        node = node_dict[name]
        nodes_name_to_remove.add(name)
        for input_node_name in node.input:
            if is_node_without_output(input_node_name):
                mark_nodes_to_remove(input_node_name)

    def parse_data(name, typ, default_value=0):
        if name in node_dict:
            node = node_dict[name]
            if node.op_type == 'Constant':
                raw_data = node.attribute[0].t.raw_data
            else:
                mark_nodes_to_remove(name)
                return default_value
        elif name in init_dict:
            raw_data = init_dict[name].raw_data
        else:
            raise ValueError(f'{name} not found in node or initializer.')
        return np.frombuffer(raw_data, typ).item()

    nrof_node = len(nodes)
    for idx in range(nrof_node):
        node = nodes[idx]
        node_attributes = node.attribute
        node_inputs = node.input
        node_outputs = node.output
        node_name = node.name
        # process NonMaxSuppression node
        if node.op_type == 'NonMaxSuppression':
            center_point_box = 0
            max_output_boxes_per_class = 1000000
            iou_threshold = 0.3
            score_threshold = 0.0
            offset = 0
            for attribute in node_attributes:
                if attribute.name == 'center_point_box':
                    center_point_box = attribute.i
                elif attribute.name == 'offset':
                    offset = attribute.i

            if len(node_inputs) >= 3:
                max_output_boxes_per_class = parse_data(
                    node_inputs[2], np.int64, max_output_boxes_per_class)
                mark_nodes_to_remove(node_inputs[2])

            if len(node_inputs) >= 4:
                iou_threshold = parse_data(node_inputs[3], np.float32,
                                           iou_threshold)
                mark_nodes_to_remove(node_inputs[3])

            if len(node_inputs) >= 5:
                score_threshold = parse_data(node_inputs[4], np.float32)
                mark_nodes_to_remove(node_inputs[4])

            new_node = onnx.helper.make_node(
                'NonMaxSuppression',
                node_inputs[:2],
                node_outputs,
                name=node_name,
                center_point_box=center_point_box,
                max_output_boxes_per_class=max_output_boxes_per_class,
                iou_threshold=iou_threshold,
                score_threshold=score_threshold,
                offset=offset)

            for output in node_outputs:
                if output in node_dict:
                    node_dict[output] = new_node
            nodes.insert(idx, new_node)
            nodes.remove(node)
        elif node.op_type == 'InstanceNormalization':
            # directly change op name
            node.op_type = 'MMCVInstanceNormalization'

    for node_name in nodes_name_to_remove:
        nodes.remove(node_dict[node_name])

    return onnx_model
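A usage sketch for the new module (file names are placeholders): the preprocessing folds the shape-related `NonMaxSuppression` inputs into attributes so TensorRT can resolve shapes on the host before building an engine.

```python
import onnx

from mmcv.tensorrt import preprocess_onnx

onnx_model = onnx.load('model.onnx')      # e.g. exported with opset 11
onnx_model = preprocess_onnx(onnx_model)  # shape inputs become attributes
onnx.save(onnx_model, 'model_preprocessed.onnx')
```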
 import warnings

-import numpy as np
 import onnx
 import tensorrt as trt
 import torch

+from .preprocess import preprocess_onnx

-
-def preprocess_onnx(onnx_model):
-    """Modify onnx model to match with TensorRT plugins in mmcv.
-
-    There are some conflict between onnx node definition and TensorRT limit.
-    This function perform preprocess on the onnx model to solve the
-    conflicts. For example, onnx `attribute` is loaded in TensorRT on host
-    and onnx `input` is loaded on device. The shape inference is performed
-    on host, so any `input` related to shape (such as
-    `max_output_boxes_per_class` in NonMaxSuppression) should be transformed
-    to `attribute` before conversion.
-
-    Arguments:
-        onnx_model (onnx.ModelProto): Input onnx model.
-
-    Returns:
-        onnx.ModelProto: Modified onnx model.
-    """
-    graph = onnx_model.graph
-    nodes = graph.node
-    initializers = graph.initializer
-    node_dict = {}
-    for node in nodes:
-        node_outputs = node.output
-        for output in node_outputs:
-            if len(output) > 0:
-                node_dict[output] = node
-
-    init_dict = {_.name: _ for _ in initializers}
-
-    def parse_data(name, typ):
-        if name in node_dict:
-            const_node = node_dict[name]
-            assert const_node.op_type == 'Constant'
-            raw_data = const_node.attribute[0].t.raw_data
-        elif name in init_dict:
-            raw_data = init_dict[name].raw_data
-        else:
-            raise ValueError(f'{name} not found in node or initializer.')
-        return np.frombuffer(raw_data, typ).item()
-
-    nrof_node = len(nodes)
-    for idx in range(nrof_node):
-        node = nodes[idx]
-        node_attributes = node.attribute
-        node_inputs = node.input
-        node_outputs = node.output
-        node_name = node.name
-        # process NonMaxSuppression node
-        if node.op_type == 'NonMaxSuppression':
-            center_point_box = 0
-            max_output_boxes_per_class = 1000000
-            iou_threshold = 0.3
-            score_threshold = 0.0
-            offset = 0
-            for attribute in node_attributes:
-                if attribute.name == 'center_point_box':
-                    center_point_box = attribute.i
-                elif attribute.name == 'offset':
-                    offset = attribute.i
-            if len(node_inputs) >= 3:
-                max_output_boxes_per_class = parse_data(
-                    node_inputs[2], np.int64)
-            if len(node_inputs) >= 4:
-                iou_threshold = parse_data(node_inputs[3], np.float32)
-            if len(node_inputs) >= 5:
-                score_threshold = parse_data(node_inputs[4], np.float32)
-            new_node = onnx.helper.make_node(
-                'NonMaxSuppression',
-                node_inputs[:2],
-                node_outputs,
-                name=node_name,
-                center_point_box=center_point_box,
-                max_output_boxes_per_class=max_output_boxes_per_class,
-                iou_threshold=iou_threshold,
-                score_threshold=score_threshold,
-                offset=offset)
-            for output in node_outputs:
-                if output in node_dict:
-                    node_dict[output] = new_node
-            nodes.insert(idx, new_node)
-            nodes.remove(node)
-        elif node.op_type == 'InstanceNormalization':
-            # directly change op name
-            node.op_type = 'MMCVInstanceNormalization'
-    return onnx_model

 def onnx2trt(onnx_model,
...
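For completeness, a hedged end-to-end sketch (assuming TensorRT and a CUDA device; the `[min, opt, max]` shape convention and `max_workspace_size` value mirror the tests below, and file names are placeholders):

```python
import onnx

from mmcv.tensorrt import onnx2trt, save_trt_engine

onnx_model = onnx.load('model_preprocessed.onnx')
opt_shape_dict = {
    'boxes': [[100, 4], [100, 4], [100, 4]],  # [min, opt, max] shapes
    'scores': [[100], [100], [100]],
}
trt_engine = onnx2trt(onnx_model, opt_shape_dict,
                      max_workspace_size=1 << 30)
save_trt_engine(trt_engine, 'model.trt')
```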
@@ -138,7 +138,12 @@ class Testnms(object):
         from mmcv.ops import batched_nms
         results = mmcv.load('./tests/data/batched_nms_data.pkl')

-        nms_cfg = dict(type='nms', iou_threshold=0.7)
+        nms_max_num = 100
+        nms_cfg = dict(
+            type='nms',
+            iou_threshold=0.7,
+            score_threshold=0.5,
+            max_num=nms_max_num)
         boxes, keep = batched_nms(
             torch.from_numpy(results['boxes']),
             torch.from_numpy(results['scores']),
@@ -156,7 +161,8 @@ class Testnms(object):
         assert torch.equal(keep, seq_keep)
         assert torch.equal(boxes, seq_boxes)
-        assert torch.equal(keep, torch.from_numpy(results['keep']))
+        assert torch.equal(keep,
+                           torch.from_numpy(results['keep'][:nms_max_num]))

         nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
         boxes, keep = batched_nms(
...
@@ -93,9 +93,12 @@ def test_nms():
     np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
-    pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
+
+    nms = partial(
+        nms, iou_threshold=0.3, offset=0, score_threshold=0, max_num=0)
+    pytorch_dets, _ = nms(boxes, scores)
     pytorch_score = pytorch_dets[:, 4]
-    nms = partial(nms, iou_threshold=0.3, offset=0)
     wrapped_model = WrapFunction(nms)
     wrapped_model.cpu().eval()
     with torch.no_grad():
@@ -106,14 +109,12 @@ def test_nms():
         keep_initializers_as_inputs=True,
         input_names=['boxes', 'scores'],
         opset_version=11)
     onnx_model = onnx.load(onnx_file)
     ort_custom_op_path = get_onnxruntime_op_path()
-    if not os.path.exists(ort_custom_op_path):
-        pytest.skip('nms for onnxruntime is not compiled.')
     session_options = rt.SessionOptions()
-    session_options.register_custom_ops_library(ort_custom_op_path)
+    if os.path.exists(ort_custom_op_path):
+        session_options.register_custom_ops_library(ort_custom_op_path)
     # get onnx output
     input_all = [node.name for node in onnx_model.graph.input]
...
@@ -126,7 +126,8 @@ def test_nms():
     data = mmcv.load('./tests/data/batched_nms_data.pkl')
     boxes = torch.from_numpy(data['boxes']).cuda()
     scores = torch.from_numpy(data['scores']).cuda()
-    nms = partial(nms, iou_threshold=0.7, offset=0)
+    nms = partial(
+        nms, iou_threshold=0.7, offset=0, score_threshold=0.1, max_num=100)
     wrapped_model = WrapFunction(nms)
     wrapped_model.cpu().eval()
     with torch.no_grad():
@@ -195,7 +196,7 @@ def test_batched_nms():
     fp16_mode = False
     max_workspace_size = 1 << 30
     data = mmcv.load('./tests/data/batched_nms_data.pkl')
-    nms_cfg = dict(type='nms', iou_threshold=0.7)
+    nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
     boxes = torch.from_numpy(data['boxes']).cuda()
     scores = torch.from_numpy(data['scores']).cuda()
     idxs = torch.from_numpy(data['idxs']).cuda()
...
import os
from functools import wraps

import onnx
import torch

from mmcv.ops import nms
from mmcv.tensorrt.preprocess import preprocess_onnx


def remove_tmp_file(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        onnx_file = 'tmp.onnx'
        kwargs['onnx_file'] = onnx_file
        try:
            result = func(*args, **kwargs)
        finally:
            if os.path.exists(onnx_file):
                os.remove(onnx_file)
        return result

    return wrapper


@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
    torch_model = module()
    torch_model.eval()

    input = (torch.rand([100, 4], dtype=torch.float32),
             torch.rand([100], dtype=torch.float32))

    torch.onnx.export(
        torch_model,
        input,
        onnx_file,
        opset_version=11,
        input_names=['boxes', 'scores'],
        output_names=['output'])

    onnx_model = onnx.load(onnx_file)
    return onnx_model


def test_can_handle_nms_with_constant_maxnum():

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4, max_num=10)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    for node in preprocess_onnx_model.graph.node:
        if 'NonMaxSuppression' in node.name:
            assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'


def test_can_handle_nms_with_undefined_maxnum():

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)
    for node in preprocess_onnx_model.graph.node:
        if 'NonMaxSuppression' in node.name:
            assert len(node.attribute) == 5, \
                'The NMS must have 5 attributes.'
            assert node.attribute[2].i > 0, \
                'The max_output_boxes_per_class is not defined correctly.'