"configs/vscode:/vscode.git/clone" did not exist on "1a3a1adea3975adbd2e27770bf174cff3c7a3df3"
Unverified Commit 48d99025 authored by z55250825's avatar z55250825 Committed by GitHub
Browse files

Add new parrots extension implementation for all ops (#794)

* delete all parrots file
add bbox_overlaps new parrots op impl

* support first new impl parrts op (bbox_overlaps)(success test)

* add box_iou_rotated op, test succeed

* add carafe and carafe_naive op, test succeed (one parrots bug need fix)

* add cc_attention op, test success

* add corner_pool op, test success

* add parrots op deform_conv, test success

* add deform_roi_pool op, test success (but has question)

* add focal loss op, test success (gradcheck)

* add masked_conv2d op, test success

* add modulated_deform_conv op, test success

* add nms and nms_rotated op, test success

* add psamask op, test success

* add roi_align op, test_success

* add roi_pool op, test success

* add sync_bn op, test success

* add tin_shift op, test success

* fix test_deform_roi_pool, add parrots test

* skip test_onnx because parrots does not support onnx

* fix c++ lint

* fix python lint

* fix python lint
parent 72e4cc12
...@@ -140,20 +140,20 @@ class Testnms(object): ...@@ -140,20 +140,20 @@ class Testnms(object):
nms_cfg = dict(type='nms', iou_threshold=0.7) nms_cfg = dict(type='nms', iou_threshold=0.7)
boxes, keep = batched_nms( boxes, keep = batched_nms(
results['boxes'], torch.from_numpy(results['boxes']),
results['scores'], torch.from_numpy(results['scores']),
results['idxs'], torch.from_numpy(results['idxs']),
nms_cfg, nms_cfg,
class_agnostic=False) class_agnostic=False)
nms_cfg.update(split_thr=100) nms_cfg.update(split_thr=100)
seq_boxes, seq_keep = batched_nms( seq_boxes, seq_keep = batched_nms(
results['boxes'], torch.from_numpy(results['boxes']),
results['scores'], torch.from_numpy(results['scores']),
results['idxs'], torch.from_numpy(results['idxs']),
nms_cfg, nms_cfg,
class_agnostic=False) class_agnostic=False)
assert torch.equal(keep, seq_keep) assert torch.equal(keep, seq_keep)
assert torch.equal(boxes, seq_boxes) assert torch.equal(boxes, seq_boxes)
assert torch.equal(keep, results['keep']) assert torch.equal(keep, torch.from_numpy(results['keep']))
...@@ -24,6 +24,8 @@ class WrapFunction(nn.Module): ...@@ -24,6 +24,8 @@ class WrapFunction(nn.Module):
def test_nms(): def test_nms():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import get_onnxruntime_op_path, nms from mmcv.ops import get_onnxruntime_op_path, nms
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
...@@ -70,6 +72,8 @@ def test_nms(): ...@@ -70,6 +72,8 @@ def test_nms():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU') @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_softnms(): def test_softnms():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import get_onnxruntime_op_path, soft_nms from mmcv.ops import get_onnxruntime_op_path, soft_nms
# only support pytorch >= 1.7.0 # only support pytorch >= 1.7.0
...@@ -144,6 +148,8 @@ def test_softnms(): ...@@ -144,6 +148,8 @@ def test_softnms():
def test_roialign(): def test_roialign():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try: try:
from mmcv.ops import roi_align from mmcv.ops import roi_align
from mmcv.ops import get_onnxruntime_op_path from mmcv.ops import get_onnxruntime_op_path
...@@ -216,6 +222,8 @@ def test_roialign(): ...@@ -216,6 +222,8 @@ def test_roialign():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU') @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_roipool(): def test_roipool():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import roi_pool from mmcv.ops import roi_pool
# roi pool config # roi pool config
...@@ -278,6 +286,8 @@ def test_roipool(): ...@@ -278,6 +286,8 @@ def test_roipool():
def test_simplify(): def test_simplify():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.onnx.simplify import simplify from mmcv.onnx.simplify import simplify
# only support PyTorch >= 1.5.0 # only support PyTorch >= 1.5.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment