Unverified commit ecbff88a, authored by Abhijit Deo, committed by GitHub
Browse files

Added CIOU loss function (#5776)



* added ciou loss

* "formatting with flake8 and ufmt"

* formatting with ufmt and flake8

* minor changes

* changes as per the suggestions

* added reference in torchvision/ops/__init__.py

* sample test

* tests formatted

* added description

* formatting

* edited tests

* changes in tests

* added tests for multiple boxes

* minor edits

* minor edit

* doc added

* minor edits

* Update test_ops.py

* formatting test file

* changes as per the suggestions

* formatting and adding some more tests

* bounding box added

* removed unnecessary comment

* added docstring

* added type annotations

* removed potential bug

* Update torchvision/ops/boxes.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/ops/boxes.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update test/test_ops.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 5fc36b4f
...@@ -19,6 +19,8 @@ Operators ...@@ -19,6 +19,8 @@ Operators
box_convert box_convert
box_iou box_iou
clip_boxes_to_image clip_boxes_to_image
complete_box_iou
complete_box_iou_loss
deform_conv2d deform_conv2d
drop_block2d drop_block2d
drop_block3d drop_block3d
......
...@@ -1258,6 +1258,43 @@ class TestGenBoxIou(BoxTestBase): ...@@ -1258,6 +1258,43 @@ class TestGenBoxIou(BoxTestBase):
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
class TestCompleteBoxIou(BoxTestBase):
    """Exercise ``ops.complete_box_iou`` through the shared ``BoxTestBase`` helpers.

    Every case feeds the same list of boxes as both operands, so the expected
    value is a square matrix whose diagonal is 1.0.
    """

    # NOTE(review): the off-diagonal expected values below equal the plain IoU
    # of the boxes (e.g. 50*50 / 100*100 = 0.25 for the first two int boxes).
    # The CIoU definition additionally subtracts a normalized center-distance
    # term, which for that pair would be 1250 / 20000 = 0.0625 -- confirm that
    # these expectations are really the intended pairwise CIoU values.

    def _target_fn(self) -> Tuple[bool, Callable]:
        """Return the ``(flag, op)`` pair the base-class runners consume.

        The meaning of the boolean is defined by ``BoxTestBase``, which is not
        part of this class.
        """
        op_under_test = ops.complete_box_iou
        return (True, op_under_test)

    def _generate_int_input() -> List[List[int]]:
        """Integer boxes: one large box, one nested inside it, one disjoint."""
        outer = [0, 0, 100, 100]
        nested = [0, 0, 50, 50]
        disjoint = [200, 200, 300, 300]
        return [outer, nested, disjoint]

    def _generate_int_expected() -> List[List[float]]:
        """Expected 3x3 result matrix for the integer boxes."""
        rows = [
            [1.0, 0.25, 0.0],
            [0.25, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
        return rows

    def _generate_float_input() -> List[List[float]]:
        """Three float boxes that overlap one another almost completely."""
        boxes = [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]
        return boxes

    def _generate_float_expected() -> List[List[float]]:
        """Expected 3x3 result matrix for the float boxes."""
        rows = [
            [1.0, 0.9933, 0.9673],
            [0.9933, 1.0, 0.9737],
            [0.9673, 0.9737, 1.0],
        ]
        return rows

    # The helpers above take no ``self`` on purpose: they are invoked as plain
    # functions while the class body is being evaluated, to build the
    # parametrization table.
    @pytest.mark.parametrize(
        "test_input, dtypes, tolerance, expected",
        [
            pytest.param(
                _generate_int_input(),
                [torch.int16, torch.int32, torch.int64],
                1e-4,
                _generate_int_expected(),
            ),
            pytest.param(
                _generate_float_input(),
                [torch.float32, torch.float64],
                0.002,
                _generate_float_expected(),
            ),
            pytest.param(
                _generate_float_input(),
                [torch.float32, torch.float64],
                0.001,
                _generate_float_expected(),
            ),
        ],
    )
    def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
        """Compare the op's output with ``expected`` for every dtype in ``dtypes``."""
        self._run_test(test_input, dtypes, tolerance, expected)

    def test_ciou_jit(self) -> None:
        """Run the base-class JIT check on the integer boxes."""
        jit_boxes = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
        self._run_jit_test(jit_boxes)
class TestMasksToBoxes: class TestMasksToBoxes:
def test_masks_box(self): def test_masks_box(self):
def masks_box_check(masks, expected, tolerance=1e-4): def masks_box_check(masks, expected, tolerance=1e-4):
...@@ -1578,6 +1615,7 @@ class TestGeneralizedBoxIouLoss: ...@@ -1578,6 +1615,7 @@ class TestGeneralizedBoxIouLoss:
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
box1s = torch.stack([box2, box2], dim=0) box1s = torch.stack([box2, box2], dim=0)
box2s = torch.stack([box3, box4], dim=0) box2s = torch.stack([box3, box4], dim=0)
...@@ -1623,5 +1661,53 @@ class TestGeneralizedBoxIouLoss: ...@@ -1623,5 +1661,53 @@ class TestGeneralizedBoxIouLoss:
assert loss.numel() == 0, "giou_loss for two empty box should be empty" assert loss.numel() == 0, "giou_loss for two empty box should be empty"
class TestCIOULoss:
    """Tests for ``ops.complete_box_iou_loss``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_ciou_loss(self, dtype, device):
        """Check the loss on fixed box pairs, unreduced and with mean/sum reduction."""
        box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
        box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
        box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
        box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

        # Batched inputs used to exercise the "mean" and "sum" reductions.
        box1s = torch.stack([box2, box2], dim=0)
        box2s = torch.stack([box3, box4], dim=0)

        def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
            output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
            expected_output = torch.tensor(expected_output, device=device)
            # Half-precision inputs carry less precision, so loosen the tolerance.
            tol = 1e-5 if dtype != torch.half else 1e-3
            torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

        assert_ciou_loss(box1, box1, 0.0)

        assert_ciou_loss(box1, box2, 0.8125)

        assert_ciou_loss(box1, box3, 1.1923)

        assert_ciou_loss(box1, box4, 1.2500)

        assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean")
        assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum")

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device) -> None:
        """Check that empty ``[0, 4]`` inputs give a zero mean loss and still backpropagate."""
        # Create the inputs on the parametrized device; without ``device=`` every
        # parametrization would silently run on the CPU.
        box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
        box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()

        loss = ops.complete_box_iou_loss(box1, box2, reduction="mean")
        loss.backward()

        tol = 1e-3 if dtype is torch.half else 1e-5
        # The reference scalar must live on the same device as the loss.
        torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
        assert box1.grad is not None, "box1.grad should not be None after backward is called"
        assert box2.grad is not None, "box2.grad should not be None after backward is called"

        loss = ops.complete_box_iou_loss(box1, box2, reduction="none")
        assert loss.numel() == 0, "ciou_loss for two empty box should be empty"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -7,9 +7,11 @@ from .boxes import ( ...@@ -7,9 +7,11 @@ from .boxes import (
box_area, box_area,
box_iou, box_iou,
generalized_box_iou, generalized_box_iou,
complete_box_iou,
masks_to_boxes, masks_to_boxes,
) )
from .boxes import box_convert from .boxes import box_convert
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d from .deform_conv import deform_conv2d, DeformConv2d
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
......
...@@ -311,6 +311,54 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ...@@ -311,6 +311,54 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
return iou - (areai - union) / areai return iou - (areai - union) / areai
def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Return complete intersection-over-union (Jaccard index) between two sets of boxes.

    The complete IoU is the plain IoU minus two penalties: the squared distance
    between the box centers normalized by the squared diagonal of the smallest
    enclosing box, and an aspect-ratio consistency term ``alpha * v``.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union

    # Smallest enclosing box of every (boxes1[i], boxes2[j]) pair. boxes1 is
    # broadcast along dim 1 and boxes2 along dim 0 so the result is [N,M,2].
    lti = torch.min(boxes1[:, None, :2], boxes2[None, :, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[None, :, 2:])

    whi = (rbi - lti).clamp(min=0)  # [N,M,2]
    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

    # centers of boxes
    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2  # [N]
    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2  # [N]
    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2  # [M]
    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2  # [M]
    # The distance between boxes' centers squared, for every pair -> [N,M].
    centers_distance_squared = (x_p[:, None] - x_g[None, :]) ** 2 + (y_p[:, None] - y_g[None, :]) ** 2

    w_pred = boxes1[:, 2] - boxes1[:, 0]
    h_pred = boxes1[:, 3] - boxes1[:, 1]

    w_gt = boxes2[:, 2] - boxes2[:, 0]
    h_gt = boxes2[:, 3] - boxes2[:, 1]

    # Aspect-ratio consistency term, again expanded to all pairs -> [N,M].
    angle_pred = torch.atan(w_pred / h_pred)  # [N]
    angle_gt = torch.atan(w_gt / h_gt)  # [M]
    v = (4 / (torch.pi ** 2)) * torch.pow(angle_gt[None, :] - angle_pred[:, None], 2)
    # alpha is a weighting factor only; it is computed without tracking gradients.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
""" """
Compute the bounding boxes around the provided masks. Compute the bounding boxes around the provided masks.
......
import torch
from ..utils import _log_api_usage_once
from .giou_loss import _upcast
def complete_box_iou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    boxes do not overlap. This loss function considers important geometrical
    factors such as overlap area, normalized central point distance and aspect ratio.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two sets of boxes should have the
    same dimensions. Boxes are compared elementwise (``boxes1[i]`` with ``boxes2[i]``).

    Args:
        boxes1 : (Tensor[N, 4] or Tensor[4]) first set of boxes
        boxes2 : (Tensor[N, 4] or Tensor[4]) second set of boxes
        reduction : (string, optional) Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
        eps : (float): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor: the per-pair loss when ``reduction`` is ``'none'``, otherwise a
        scalar tensor holding the mean or the sum of the per-pair losses.

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Reference:
        Complete Intersection over Union Loss (Zhaohui Zheng et. al)
        https://arxiv.org/abs/1911.08287
    """

    # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou_loss)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Intersection area stays zero wherever the boxes do not overlap.
    intsct = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
    iou = intsct / union

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    # Squared diagonal of the enclosing box (eps keeps the later division safe).
    diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

    # centers of boxes
    x_p = (x2 + x1) / 2
    y_p = (y2 + y1) / 2
    x_g = (x1g + x2g) / 2
    y_g = (y1g + y2g) / 2
    # Squared distance between the two centers.
    distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

    # width and height of boxes
    w_pred = x2 - x1
    h_pred = y2 - y1
    w_gt = x2g - x1g
    h_gt = y2g - y1g

    # Aspect-ratio consistency term and its weight; alpha is computed without
    # tracking gradients.
    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    loss = 1 - iou + (distance / diag_len) + alpha * v

    if reduction == "none":
        pass
    elif reduction == "mean":
        # ``mean()`` of an empty tensor is NaN; ``0.0 * sum()`` yields a zero that
        # is still connected to the autograd graph.
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        # Reject unknown modes instead of silently returning the unreduced loss.
        raise ValueError(
            f"Invalid value for arg 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'"
        )

    return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment