# Copyright 2021 Toyota Research Institute. All rights reserved.
# Adapted from AdelaiDet:
# https://github.com/aim-uofa/AdelaiDet/blob/master/adet/layers/iou_loss.py
import torch
from torch import nn


class IOULoss(nn.Module):
    """
    Intersection Over Union (IoU) loss which supports three different IoU computations:

    * IoU        -> ``-log(IoU)``  (``loc_loss_type='iou'``)
    * Linear IoU -> ``1 - IoU``    (``loc_loss_type='linear_iou'``)
    * gIoU       -> ``1 - gIoU``   (``loc_loss_type='giou'``)

    Boxes are encoded as (left, top, right, bottom) distances measured from a
    shared reference point: a box's width is ``left + right`` and its height is
    ``top + bottom``.
    """

    # Lower bound for the enclosing-box area used as the gIoU denominator. Two
    # degenerate (zero-area) boxes would otherwise produce 0 / 0 = NaN.
    _ENCLOSING_AREA_EPS = 1e-7

    def __init__(self, loc_loss_type='iou'):
        """
        Args:
            loc_loss_type: one of 'iou', 'linear_iou' or 'giou'. Any other value is
                accepted here but makes ``forward`` raise NotImplementedError.
        """
        super().__init__()
        self.loc_loss_type = loc_loss_type

    def forward(self, pred, target, weight=None):
        """
        Args:
            pred: Nx4 predicted bounding boxes as (left, top, right, bottom) distances.
            target: Nx4 target bounding boxes, same encoding as ``pred``.
            weight: optional N loss weight for each instance.

        Returns:
            Scalar tensor: the sum of per-instance losses, each multiplied by
            ``weight`` when it is given.

        Raises:
            NotImplementedError: if ``loc_loss_type`` is not a supported value.
        """
        pred_left = pred[:, 0]
        pred_top = pred[:, 1]
        pred_right = pred[:, 2]
        pred_bottom = pred[:, 3]

        target_left = target[:, 0]
        target_top = target[:, 1]
        target_right = target[:, 2]
        target_bottom = target[:, 3]

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right)
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top)
        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect

        # +1 smoothing keeps the ratio finite even when both boxes have zero
        # area. For non-negative distances it also keeps it strictly positive,
        # so log() below is defined.
        ious = (area_intersect + 1.0) / (area_union + 1.0)

        if self.loc_loss_type == 'iou':
            losses = -torch.log(ious)
        elif self.loc_loss_type == 'linear_iou':
            losses = 1 - ious
        elif self.loc_loss_type == 'giou':
            # The smallest box enclosing both pred and target is only needed for gIoU.
            g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right)
            g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top)
            area_enclosing = g_w_intersect * g_h_intersect
            # Clamp only the denominator. Normal boxes are unaffected, and a
            # degenerate pair (enclosing == union == 0) gets a zero penalty
            # instead of NaN. This assumes non-negative target distances, so
            # the enclosing area is >= 0 — confirm for other encodings.
            penalty = (area_enclosing - area_union) / area_enclosing.clamp(
                min=self._ENCLOSING_AREA_EPS)
            gious = ious - penalty
            losses = 1 - gious
        else:
            raise NotImplementedError(
                f"Unsupported loc_loss_type: {self.loc_loss_type!r}")

        if weight is not None:
            return (losses * weight).sum()
        else:
            return losses.sum()