# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.registry import MODELS


@weighted_loss
def multibin_loss(pred_orientations: Tensor,
                  gt_orientations: Tensor,
                  num_dir_bins: int = 4) -> Tensor:
    """Multi-Bin Loss.

    For each of the ``num_dir_bins`` bins the loss combines a 2-way
    cross-entropy classification term with an L1 regression term on the
    (sin, cos) of the in-bin angle offset. The regression term is only
    evaluated for samples whose ground-truth label for that bin equals 1.

    Args:
        pred_orientations (Tensor): Predicted local vector orientation in
            [axis_cls, head_cls, sin, cos] format, shape
            (N, num_dir_bins * 4). Columns ``[2 * i, 2 * i + 2)`` hold the
            two classification logits of bin ``i``; columns
            ``[2 * num_dir_bins + 2 * i, 2 * num_dir_bins + 2 * i + 2)``
            hold the predicted (sin, cos) offset of bin ``i``, which is
            L2-normalized (``F.normalize``) before comparison.
        gt_orientations (Tensor): Corresponding ground truth, shape
            (N, num_dir_bins * 2). Column ``i`` is the class label of bin
            ``i`` (cast to long; a value of 1 marks a bin whose offset is
            regressed) and column ``num_dir_bins + i`` is the offset angle
            of bin ``i`` fed to ``sin``/``cos`` (presumably radians).
        num_dir_bins (int): Number of bins to encode direction angle.
            Defaults to 4.

    Returns:
        Tensor: Scalar loss, i.e. the classification loss averaged over
        bins plus the summed regression error averaged over all
        (sample, bin) pairs with label 1. If no such pair exists, only the
        classification term is returned.
    """
    cls_losses = 0
    reg_losses = 0
    reg_cnt = 0
    for i in range(num_dir_bins):
        # bin classification loss (mean over the N samples)
        cls_losses += F.cross_entropy(
            pred_orientations[:, (i * 2):(i * 2 + 2)],
            gt_orientations[:, i].long(),
            reduction='mean')

        # regression loss, only for samples that fall into bin i
        valid_mask_i = (gt_orientations[:, i] == 1)
        num_valid = valid_mask_i.sum()
        if num_valid > 0:
            start = num_dir_bins * 2 + i * 2
            end = start + 2
            pred_offset = F.normalize(pred_orientations[valid_mask_i,
                                                        start:end])
            gt_offset = gt_orientations[valid_mask_i, num_dir_bins + i]
            reg_loss = \
                F.l1_loss(pred_offset[:, 0], torch.sin(gt_offset),
                          reduction='none') + \
                F.l1_loss(pred_offset[:, 1], torch.cos(gt_offset),
                          reduction='none')

            reg_losses += reg_loss.sum()
            reg_cnt += num_valid

    # Combine the two terms only after every bin has been accumulated.
    cls_loss = cls_losses / num_dir_bins
    if reg_cnt == 0:
        # No positive bin in the whole batch: reg_losses / reg_cnt would be
        # 0 / 0 on Python ints (ZeroDivisionError), so keep only the
        # classification term.
        return cls_loss
    return cls_loss + reg_losses / reg_cnt


@MODELS.register_module()
class MultiBinLoss(nn.Module):
    """Multi-Bin Loss for orientation.

    Module wrapper around ``multibin_loss`` that forwards the chosen
    reduction mode and scales the result by ``loss_weight``.

    Args:
        reduction (str): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        loss_weight (float): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'none',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                num_dir_bins: int = 4,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction, in the layout expected by
                ``multibin_loss``, shape (N, num_dir_bins * 4).
            target (Tensor): The learning target of the prediction,
                shape (N, num_dir_bins * 2).
            num_dir_bins (int): Number of bins to encode direction angle.
                Defaults to 4, matching ``multibin_loss``.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Loss tensor multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override takes precedence over the constructor setting.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * multibin_loss(
            pred, target, num_dir_bins=num_dir_bins, reduction=reduction)
        return loss