Unverified Commit 1ae38297 authored by Yassine Alouini's avatar Yassine Alouini Committed by GitHub
Browse files

Distance IoU (#5786)



* [FEAT] Add distance IoU and distance IoU loss + some tests (WIP for tests).

* [FIX] Remove URL from docstring + remove assert since it causes a big performance drop.

* [FIX] eps isn't None.

* [TEST] Update existing box dIoU test + add dIoU loss tests (inspired from cIoU ones).

* [ENH] Some pre-commit fixes + remove print + mypy.

* [ENH] Pass the device in the assertion for the dIoU loss test.

* [FIX] Remove type hints from the dIoU box test.

* [ENH] Refactor box and loss for dIoU functions + fix half tests.

* [FIX] Precommits fix.

* [ENH] Some improvement for the distance IoU tests thanks to code review.

* [ENH] Upcast in distance boxes computation to avoid overflow.

* [ENH] Revert the refactor of distance IoU loss back since it introduced a bug and can be slow.

* Precommit fix.

* [FIX] Few changes introduced by merge conflict.

* Add code reference

* Fix test
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent c1868581
......@@ -26,6 +26,8 @@ Operators
drop_block3d
generalized_box_iou
generalized_box_iou_loss
distance_box_iou
distance_box_iou_loss
masks_to_boxes
nms
ps_roi_align
......
......@@ -1258,6 +1258,97 @@ class TestGenBoxIou(BoxTestBase):
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
class TestDistanceBoxIoU(BoxTestBase):
def _target_fn(self):
return (True, ops.distance_box_iou)
def _generate_int_input():
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
def _generate_float_input():
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]
def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
],
)
def test_distance_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(test_input, dtypes, tolerance, expected)
def test_distance_iou_jit(self):
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_distance_iou_loss(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
box1s = torch.stack(
[box2, box2],
dim=0,
)
box2s = torch.stack(
[box3, box4],
dim=0,
)
def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
# TODO: When passing the dtype, the torch.half fails as usual.
expected_output = torch.tensor(expected_output, device=device)
tol = 1e-5 if dtype != torch.half else 1e-3
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
assert_distance_iou_loss(box1, box1, 0.0)
assert_distance_iou_loss(box1, box2, 0.8125)
assert_distance_iou_loss(box1, box3, 1.1923)
assert_distance_iou_loss(box1, box4, 1.2500)
assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean")
assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum")
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(dtype, device) -> None:
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
loss = ops.distance_box_iou_loss(box1, box2, reduction="mean")
loss.backward()
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
assert box1.grad is not None, "box1.grad should not be None after backward is called"
assert box2.grad is not None, "box2.grad should not be None after backward is called"
loss = ops.distance_box_iou_loss(box1, box2, reduction="none")
assert loss.numel() == 0, "diou_loss for two empty box should be empty"
class TestCompleteBoxIou(BoxTestBase):
def _target_fn(self) -> Tuple[bool, Callable]:
return (True, ops.complete_box_iou)
......@@ -1676,6 +1767,7 @@ class TestCIOULoss:
def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
# TODO: When passing the dtype, the torch.half test doesn't pass...
expected_output = torch.tensor(expected_output, device=device)
tol = 1e-5 if dtype != torch.half else 1e-3
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
......
......@@ -7,12 +7,14 @@ from .boxes import (
box_area,
box_iou,
generalized_box_iou,
distance_box_iou,
complete_box_iou,
masks_to_boxes,
)
from .boxes import box_convert
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d
from .diou_loss import distance_box_iou_loss
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
......@@ -40,6 +42,8 @@ __all__ = [
"box_area",
"box_iou",
"generalized_box_iou",
"distance_box_iou",
"complete_box_iou",
"roi_align",
"RoIAlign",
"roi_pool",
......@@ -58,6 +62,8 @@ __all__ = [
"Conv3dNormActivation",
"SqueezeExcitation",
"generalized_box_iou_loss",
"distance_box_iou_loss",
"complete_box_iou_loss",
"drop_block2d",
"DropBlock2d",
"drop_block3d",
......
......@@ -359,6 +359,50 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
"""
Return distance intersection-over-union (Jaccard index) between two sets of boxes.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
for every element in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou)
boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared)
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks.
......
import torch
from ..utils import _log_api_usage_once
from .boxes import _upcast
def distance_box_iou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
boxes, the distance IoU is the same as the IoU loss.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the
same dimensions.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[N, 4]): second set of boxes
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
applied to the output. ``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor: Loss tensor with the reduction option applied.
Reference:
Zhaohui Zheng et. al: Distance Intersection over Union Loss:
https://arxiv.org/abs/1911.08287
"""
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou_loss)
boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union
# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
# The diagonal distance of the smallest enclosing box squared
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
# The distance between boxes' centers squared.
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
......@@ -36,7 +36,7 @@ def generalized_box_iou_loss(
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
applied to the output. ``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``
eps (float, optional): small number to prevent division by zero. Default: 1e-7
eps (float): small number to prevent division by zero. Default: 1e-7
Reference:
Hamid Rezatofighi et. al: Generalized Intersection over Union:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment