Unverified commit ae1d7071, authored by Aditya Oke, committed by GitHub
Browse files

Cleanup ops (#6024)



* Cleanup ops

* Address nits
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 49ee65f0
...@@ -5,13 +5,13 @@ from .boxes import ( ...@@ -5,13 +5,13 @@ from .boxes import (
remove_small_boxes, remove_small_boxes,
clip_boxes_to_image, clip_boxes_to_image,
box_area, box_area,
box_convert,
box_iou, box_iou,
generalized_box_iou, generalized_box_iou,
distance_box_iou, distance_box_iou,
complete_box_iou, complete_box_iou,
masks_to_boxes, masks_to_boxes,
) )
from .boxes import box_convert
from .ciou_loss import complete_box_iou_loss from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d from .deform_conv import deform_conv2d, DeformConv2d
from .diou_loss import distance_box_iou_loss from .diou_loss import distance_box_iou_loss
......
...@@ -67,3 +67,40 @@ def split_normalization_params( ...@@ -67,3 +67,40 @@ def split_normalization_params(
else: else:
other_params.extend(p for p in module.parameters() if p.requires_grad) other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params return norm_params, other_params
def _upcast(t: Tensor) -> Tensor:
    """
    Promote a low-precision tensor so later multiplications are less likely to overflow.

    Args:
        t (Tensor): tensor to promote.

    Returns:
        Tensor: ``t`` itself (same object, no copy) when its dtype is already
        ``float32``/``float64`` or ``int32``/``int64``; otherwise the result of
        ``t.float()`` for floating-point inputs or ``t.int()`` for every other dtype.
    """
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        # Everything that is not floating point (including bool) goes through ``int()``.
        return t if t.dtype in (torch.int32, torch.int64) else t.int()
def _upcast_non_float(t: Tensor) -> Tensor:
    """
    Convert any tensor that is not ``float32``/``float64`` to ``float32``.

    Unlike ``_upcast``, integer and bool inputs are not kept as integers: every
    dtype other than ``float32``/``float64`` is converted with ``t.float()``.

    Args:
        t (Tensor): tensor to convert.

    Returns:
        Tensor: ``t`` itself (same object, no copy) when it is already
        ``float32`` or ``float64``; otherwise ``t.float()``.
    """
    # Protects from numerical overflows in multiplications by converting to float32
    if t.dtype not in (torch.float32, torch.float64):
        return t.float()
    return t
def _loss_inter_union(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the element-wise intersection and union areas of paired boxes.

    Both inputs are unbound along their last dimension into four components,
    read as ``(x1, y1, x2, y2)``.

    Args:
        boxes1 (Tensor): first set of boxes; last dimension must have size 4.
        boxes2 (Tensor): second set of boxes; last dimension must have size 4.
            NOTE(review): presumably the same shape as ``boxes1`` — confirm with callers.

    Returns:
        Tuple[Tensor, Tensor]: ``(intersection, union)``. The intersection is zero
        wherever the boxes do not strictly overlap on both axes. No epsilon is
        added to the union, so callers must guard the division themselves.
    """
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Start from zeros and fill only the strictly overlapping pairs, so disjoint
    # or merely touching boxes contribute exactly zero intersection.
    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])

    # union = area(boxes1) + area(boxes2) - intersection
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

    return intsctk, unionk
...@@ -7,6 +7,7 @@ from torchvision.extension import _assert_has_ops ...@@ -7,6 +7,7 @@ from torchvision.extension import _assert_has_ops
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
from ._utils import _upcast
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
...@@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: ...@@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes return boxes
def _upcast(t: Tensor) -> Tensor:
    """
    Promote a low-precision tensor so later multiplications are less likely to overflow.

    Args:
        t (Tensor): tensor to promote.

    Returns:
        Tensor: ``t`` unchanged when its dtype is ``float32``/``float64`` or
        ``int32``/``int64``; otherwise ``t.float()`` for floating-point inputs
        and ``t.int()`` for all remaining dtypes.
    """
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()
def box_area(boxes: Tensor) -> Tensor: def box_area(boxes: Tensor) -> Tensor:
""" """
Computes the area of a set of bounding boxes, which are specified by their Computes the area of a set of bounding boxes, which are specified by their
...@@ -330,22 +323,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso ...@@ -330,22 +323,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
boxes1 = _upcast(boxes1) boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2) boxes2 = _upcast(boxes2)
inter, union = _box_inter_union(boxes1, boxes2) diou, iou = _box_diou_iou(boxes1, boxes2, eps)
iou = inter / union
lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:])
whi = (rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
w_pred = boxes1[:, 2] - boxes1[:, 0] w_pred = boxes1[:, 2] - boxes1[:, 0]
h_pred = boxes1[:, 3] - boxes1[:, 1] h_pred = boxes1[:, 3] - boxes1[:, 1]
...@@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso ...@@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
with torch.no_grad(): with torch.no_grad():
alpha = v / (1 - iou + v + eps) alpha = v / (1 - iou + v + eps)
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v return diou - alpha * v
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor: def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
...@@ -380,16 +358,17 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso ...@@ -380,16 +358,17 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
boxes1 = _upcast(boxes1) boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2) boxes2 = _upcast(boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2)
return diou
inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
iou = box_iou(boxes1, boxes2)
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
# centers of boxes # centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
...@@ -397,10 +376,9 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso ...@@ -397,10 +376,9 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared. # The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
# The distance IoU is the IoU penalized by a normalized # The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared. # distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared) return iou - (centers_distance_squared / diagonal_distance_squared), iou
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
......
import torch import torch
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from .giou_loss import _upcast from ._utils import _upcast_non_float
from .diou_loss import _diou_iou_loss
def complete_box_iou_loss( def complete_box_iou_loss(
...@@ -30,50 +31,28 @@ def complete_box_iou_loss( ...@@ -30,50 +31,28 @@ def complete_box_iou_loss(
``'sum'``: The output will be summed. Default: ``'none'`` ``'sum'``: The output will be summed. Default: ``'none'``
eps : (float): small number to prevent division by zero. Default: 1e-7 eps : (float): small number to prevent division by zero. Default: 1e-7
Reference: Returns:
Tensor: Loss tensor with the reduction option applied.
Complete Intersection over Union Loss (Zhaohui Zheng et. al) Reference:
https://arxiv.org/abs/1911.08287 Zhaohui Zheng et. al: Complete Intersection over Union Loss:
https://arxiv.org/abs/1911.08287
""" """
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(complete_box_iou_loss) _log_api_usage_once(complete_box_iou_loss)
boxes1 = _upcast(boxes1) boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast(boxes2) boxes2 = _upcast_non_float(boxes2)
diou_loss, iou = _diou_iou_loss(boxes1, boxes2)
x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union
# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
# width and height of boxes # width and height of boxes
w_pred = x2 - x1 w_pred = x2 - x1
h_pred = y2 - y1 h_pred = y2 - y1
...@@ -83,7 +62,7 @@ def complete_box_iou_loss( ...@@ -83,7 +62,7 @@ def complete_box_iou_loss(
with torch.no_grad(): with torch.no_grad():
alpha = v / (1 - iou + v + eps) alpha = v / (1 - iou + v + eps)
loss = 1 - iou + (distance / diag_len) + alpha * v loss = diou_loss + alpha * v
if reduction == "mean": if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum": elif reduction == "sum":
......
from typing import Tuple
import torch import torch
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from .boxes import _upcast from ._utils import _loss_inter_union, _upcast_non_float
def distance_box_iou_loss( def distance_box_iou_loss(
...@@ -10,6 +12,7 @@ def distance_box_iou_loss( ...@@ -10,6 +12,7 @@ def distance_box_iou_loss(
reduction: str = "none", reduction: str = "none",
eps: float = 1e-7, eps: float = 1e-7,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Gradient-friendly IoU loss with an additional penalty that is non-zero when the Gradient-friendly IoU loss with an additional penalty that is non-zero when the
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
...@@ -37,37 +40,40 @@ def distance_box_iou_loss( ...@@ -37,37 +40,40 @@ def distance_box_iou_loss(
https://arxiv.org/abs/1911.08287 https://arxiv.org/abs/1911.08287
""" """
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou_loss) _log_api_usage_once(distance_box_iou_loss)
boxes1 = _upcast(boxes1) boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast(boxes2) boxes2 = _upcast_non_float(boxes2)
x1, y1, x2, y2 = boxes1.unbind(dim=-1) loss, _ = _diou_iou_loss(boxes1, boxes2, eps)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# Intersection keypoints if reduction == "mean":
xkis1 = torch.max(x1, x1g) loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
ykis1 = torch.max(y1, y1g) elif reduction == "sum":
xkis2 = torch.min(x2, x2g) loss = loss.sum()
ykis2 = torch.min(y2, y2g) return loss
intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union
def _diou_iou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
eps: float = 1e-7,
) -> Tuple[torch.Tensor, torch.Tensor]:
intsct, union = _loss_inter_union(boxes1, boxes2)
iou = intsct / (union + eps)
# smallest enclosing box # smallest enclosing box
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
xc1 = torch.min(x1, x1g) xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g) yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g) xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g) yc2 = torch.max(y2, y2g)
# The diagonal distance of the smallest enclosing box squared # The diagonal distance of the smallest enclosing box squared
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
# centers of boxes # centers of boxes
x_p = (x2 + x1) / 2 x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2 y_p = (y2 + y1) / 2
...@@ -75,12 +81,7 @@ def distance_box_iou_loss( ...@@ -75,12 +81,7 @@ def distance_box_iou_loss(
y_g = (y1g + y2g) / 2 y_g = (y1g + y2g) / 2
# The distance between boxes' centers squared. # The distance between boxes' centers squared.
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
# The distance IoU is the IoU penalized by a normalized # The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared. # distance between boxes' centers squared.
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared) loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
if reduction == "mean": return loss, iou
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
import torch import torch
from torch import Tensor
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _upcast_non_float, _loss_inter_union
def _upcast(t: Tensor) -> Tensor:
    """
    Convert any tensor that is not ``float32``/``float64`` to ``float32``.

    Args:
        t (Tensor): tensor to convert.

    Returns:
        Tensor: ``t`` unchanged when it is already ``float32`` or ``float64``;
        otherwise ``t.float()`` (this includes integer and bool inputs).
    """
    # Protects from numerical overflows in multiplications by converting to float32
    if t.dtype not in (torch.float32, torch.float64):
        return t.float()
    return t
def generalized_box_iou_loss( def generalized_box_iou_loss(
...@@ -17,10 +10,8 @@ def generalized_box_iou_loss( ...@@ -17,10 +10,8 @@ def generalized_box_iou_loss(
reduction: str = "none", reduction: str = "none",
eps: float = 1e-7, eps: float = 1e-7,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Original implementation from
https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py
"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the Gradient-friendly IoU loss with an additional penalty that is non-zero when the
boxes do not overlap and scales with the size of their smallest enclosing box. boxes do not overlap and scales with the size of their smallest enclosing box.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
...@@ -38,31 +29,28 @@ def generalized_box_iou_loss( ...@@ -38,31 +29,28 @@ def generalized_box_iou_loss(
``'sum'``: The output will be summed. Default: ``'none'`` ``'sum'``: The output will be summed. Default: ``'none'``
eps (float): small number to prevent division by zero. Default: 1e-7 eps (float): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor: Loss tensor with the reduction option applied.
Reference: Reference:
Hamid Rezatofighi et. al: Generalized Intersection over Union: Hamid Rezatofighi et. al: Generalized Intersection over Union:
A Metric and A Loss for Bounding Box Regression: A Metric and A Loss for Bounding Box Regression:
https://arxiv.org/abs/1902.09630 https://arxiv.org/abs/1902.09630
""" """
# Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(generalized_box_iou_loss) _log_api_usage_once(generalized_box_iou_loss)
boxes1 = _upcast(boxes1) boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast(boxes2) boxes2 = _upcast_non_float(boxes2)
intsctk, unionk = _loss_inter_union(boxes1, boxes2)
iouk = intsctk / (unionk + eps)
x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsctk = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
iouk = intsctk / (unionk + eps)
# smallest enclosing box # smallest enclosing box
xc1 = torch.min(x1, x1g) xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g) yc1 = torch.min(y1, y1g)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment