Unverified Commit 5320f742 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Adds torchscript Compatibility to box_convert (#2737)

* fixies small bug in box_convert

* activates jit test

* Passes JIT test

* fixes typo

* adds error tests, removes assert

* implements to proposal2
parent 6e639d3e
...@@ -727,20 +727,30 @@ class BoxTester(unittest.TestCase): ...@@ -727,20 +727,30 @@ class BoxTester(unittest.TestCase):
self.assertEqual(box_xywh.dtype, box_tensor.dtype) self.assertEqual(box_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, box_tensor)).item() assert torch.all(torch.eq(box_xywh, box_tensor)).item()
# def test_bbox_convert_jit(self): def test_bbox_invalid(self):
# box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
# [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
invalid_infmts = ["xwyh", "cxwyh"]
invalid_outfmts = ["xwcx", "xhwcy"]
for inv_infmt in invalid_infmts:
for inv_outfmt in invalid_outfmts:
self.assertRaises(ValueError, ops.box_convert, box_tensor, inv_infmt, inv_outfmt)
def test_bbox_convert_jit(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
# scripted_fn = torch.jit.script(ops.box_convert) scripted_fn = torch.jit.script(ops.box_convert)
# TOLERANCE = 1e-3 TOLERANCE = 1e-3
# box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
# scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh') scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
# self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE) self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)
# box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
# scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh') scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
# self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE) self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
class BoxAreaTester(unittest.TestCase): class BoxAreaTester(unittest.TestCase):
......
...@@ -77,7 +77,7 @@ def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor: ...@@ -77,7 +77,7 @@ def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
boxes (Tensor[N, 4]): boxes in (x, y, w, h) format. boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
""" """
x1, y1, x2, y2 = boxes.unbind(-1) x1, y1, x2, y2 = boxes.unbind(-1)
x2 = x2 - x1 # x2 - x1 w = x2 - x1 # x2 - x1
y2 = y2 - y1 # y2 - y1 h = y2 - y1 # y2 - y1
boxes = torch.stack((x1, y1, x2, y2), dim=-1) boxes = torch.stack((x1, y1, w, h), dim=-1)
return boxes return boxes
...@@ -154,39 +154,33 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: ...@@ -154,39 +154,33 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
Returns: Returns:
boxes (Tensor[N, 4]): Boxes into converted format. boxes (Tensor[N, 4]): Boxes into converted format.
""" """
allowed_fmts = ("xyxy", "xywh", "cxcywh") allowed_fmts = ("xyxy", "xywh", "cxcywh")
assert in_fmt in allowed_fmts if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
assert out_fmt in allowed_fmts raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")
if in_fmt == out_fmt: if in_fmt == out_fmt:
boxes_converted = boxes.clone() return boxes.clone()
return boxes_converted
if in_fmt != 'xyxy' and out_fmt != 'xyxy': if in_fmt != 'xyxy' and out_fmt != 'xyxy':
# convert to xyxy and change in_fmt xyxy
if in_fmt == "xywh": if in_fmt == "xywh":
boxes_xyxy = _box_xywh_to_xyxy(boxes) boxes = _box_xywh_to_xyxy(boxes)
if out_fmt == "cxcywh":
boxes_converted = _box_xyxy_to_cxcywh(boxes_xyxy)
elif in_fmt == "cxcywh": elif in_fmt == "cxcywh":
boxes_xyxy = _box_cxcywh_to_xyxy(boxes) boxes = _box_cxcywh_to_xyxy(boxes)
if out_fmt == "xywh": in_fmt = 'xyxy'
boxes_converted = _box_xyxy_to_xywh(boxes_xyxy)
if in_fmt == "xyxy":
# convert one to xyxy and change either in_fmt or out_fmt to xyxy if out_fmt == "xywh":
else: boxes = _box_xyxy_to_xywh(boxes)
if in_fmt == "xyxy": elif out_fmt == "cxcywh":
if out_fmt == "xywh": boxes = _box_xyxy_to_cxcywh(boxes)
boxes_converted = _box_xyxy_to_xywh(boxes) elif out_fmt == "xyxy":
elif out_fmt == "cxcywh": if in_fmt == "xywh":
boxes_converted = _box_xyxy_to_cxcywh(boxes) boxes = _box_xywh_to_xyxy(boxes)
elif out_fmt == "xyxy": elif in_fmt == "cxcywh":
if in_fmt == "xywh": boxes = _box_cxcywh_to_xyxy(boxes)
boxes_converted = _box_xywh_to_xyxy(boxes) return boxes
elif in_fmt == "cxcywh":
boxes_converted = _box_cxcywh_to_xyxy(boxes)
return boxes_converted
def box_area(boxes: Tensor) -> Tensor: def box_area(boxes: Tensor) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment