# rotated_iou_loss.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmcv.ops import diff_iou_rotated_3d
6
from mmdet.models.losses.utils import weighted_loss
7
8
9
10
11
12
13
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS


@weighted_loss
def rotated_iou_3d_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes.

    Note that predictions and targets are in one-to-one correspondence.

    Args:
        pred (Tensor): Bbox predictions with shape [N, 7]
            (x, y, z, w, l, h, alpha).
        target (Tensor): Bbox targets (gt) with shape [N, 7]
            (x, y, z, w, l, h, alpha).

    Returns:
        Tensor: IoU loss between predictions and targets.
    """
    # A leading axis of size 1 is added to both inputs for the call to
    # ``diff_iou_rotated_3d`` and indexed away again from its result.
    iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0),
                                       target.unsqueeze(0))[0]
    return iou_loss


@MODELS.register_module()
class RotatedIoU3DLoss(nn.Module):
    """Calculate the IoU loss (1-IoU) of rotated bounding boxes.

    Args:
        reduction (str): Method to reduce losses.
            Valid reduction methods are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function of loss calculation.

        Args:
            pred (Tensor): Bbox predictions with shape [..., 7]
                (x, y, z, w, l, h, alpha).
            target (Tensor): Bbox targets (gt) with shape [..., 7]
                (x, y, z, w, l, h, alpha).
            weight (Tensor, optional): Weight of loss.
                Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Method to reduce losses.
                Valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None, in which case ``self.reduction`` is used.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``rotated_iou_3d_loss``.

        Returns:
            Tensor: IoU loss between predictions and targets.
        """
        # Short-circuit when no weight is positive. The product keeps the
        # result tied to ``pred`` and is 0 as long as the weights are
        # non-negative.
        # NOTE: this path returns a scalar whatever reduction was requested,
        # and it runs before ``reduction_override`` is validated below.
        if weight is not None and not torch.any(weight > 0):
            return pred.sum() * weight.sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # A weight with more than one dimension is averaged over its last
        # axis (presumably per-coordinate weights collapsed to one value per
        # box -- confirm against callers).
        if weight is not None and weight.dim() > 1:
            weight = weight.mean(-1)
        loss = self.loss_weight * rotated_iou_3d_loss(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)

        return loss