Unverified Commit 5320f742 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Adds torchscript Compatibility to box_convert (#2737)

* fixes small bug in box_convert

* activates jit test

* Passes JIT test

* fixes typo

* adds error tests, removes assert

* implements proposal 2
parent 6e639d3e
......@@ -727,20 +727,30 @@ class BoxTester(unittest.TestCase):
self.assertEqual(box_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, box_tensor)).item()
# def test_bbox_convert_jit(self):
# box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
# [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
def test_bbox_invalid(self):
    """box_convert must raise ValueError for unsupported format strings."""
    boxes = torch.tensor(
        [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]],
        dtype=torch.float,
    )
    # Every (bad_in, bad_out) pairing should be rejected before any conversion runs.
    for bad_in in ("xwyh", "cxwyh"):
        for bad_out in ("xwcx", "xhwcy"):
            with self.assertRaises(ValueError):
                ops.box_convert(boxes, bad_in, bad_out)
def test_bbox_convert_jit(self):
    """Scripted box_convert must agree with the eager implementation.

    Checks xyxy->xywh and xyxy->cxcywh against torch.jit.script output
    within a small absolute tolerance.
    """
    # NOTE: the previous commented-out draft of this test has been removed;
    # only the active assertions remain.
    box_tensor = torch.tensor(
        [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]],
        dtype=torch.float,
    )
    scripted_fn = torch.jit.script(ops.box_convert)
    TOLERANCE = 1e-3

    box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
    scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
    self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)

    box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
    scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
    self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
class BoxAreaTester(unittest.TestCase):
......
......@@ -77,7 +77,7 @@ def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
"""
x1, y1, x2, y2 = boxes.unbind(-1)
x2 = x2 - x1 # x2 - x1
y2 = y2 - y1 # y2 - y1
boxes = torch.stack((x1, y1, x2, y2), dim=-1)
w = x2 - x1 # x2 - x1
h = y2 - y1 # y2 - y1
boxes = torch.stack((x1, y1, w, h), dim=-1)
return boxes
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts boxes from in_fmt to out_fmt. Supported formats:

    'xyxy': (x1, y1, x2, y2) — top-left and bottom-right corners.
    'xywh': (x, y, w, h) — top-left corner plus width and height.
    'cxcywh': (cx, cy, w, h) — box center plus width and height.

    Arguments:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): input format; one of 'xyxy', 'xywh', 'cxcywh'.
        out_fmt (str): output format; one of 'xyxy', 'xywh', 'cxcywh'.

    Returns:
        boxes (Tensor[N, 4]): Boxes into converted format.

    Raises:
        ValueError: if in_fmt or out_fmt is not a supported format.
    """
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    # Raise (not assert) so validation survives `python -O` and TorchScript.
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")

    if in_fmt == out_fmt:
        # Clone so callers never receive (and mutate) the input tensor.
        return boxes.clone()

    if in_fmt != 'xyxy' and out_fmt != 'xyxy':
        # Neither side is xyxy: go through xyxy as the intermediate format
        # and fall through to the xyxy -> out_fmt conversion below.
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
        in_fmt = 'xyxy'

    # Exactly one side is xyxy at this point; convert directly.
    if in_fmt == "xyxy":
        if out_fmt == "xywh":
            boxes = _box_xyxy_to_xywh(boxes)
        elif out_fmt == "cxcywh":
            boxes = _box_xyxy_to_cxcywh(boxes)
    elif out_fmt == "xyxy":
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
    return boxes
def box_area(boxes: Tensor) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment