Unverified Commit 64b1e279 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding ciou and diou support in `_box_loss()` (#5984)

* Adding ciou and diou support in `_box_loss()`

* Fix linter

* Addressing comments for nits
parent 3ec4b949
......@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss
from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss
class BalancedPositiveNegativeSampler:
......@@ -518,7 +518,7 @@ def _box_loss(
bbox_regression_per_image: Tensor,
cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")
torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
if type == "l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
......@@ -527,7 +527,12 @@ def _box_loss(
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
else: # giou
else:
bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
if type == "ciou":
return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
if type == "diou":
return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
# otherwise giou
return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment