losses.py 1.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/losses/centernet_loss.py # noqa

import torch
from torch import nn

from mmdet3d.registry import MODELS


def _gather_feat(feat, ind, mask=None):
    """Pick feature vectors out of ``feat`` at the positions given by ``ind``.

    Args:
        feat (torch.Tensor): Features of shape B x N x C.
        ind (torch.Tensor): Integer indices into dim 1 of ``feat``,
            shape B x M.
        mask (torch.Tensor, optional): Selector of shape B x M used to
            index the gathered result; only selected rows are kept.
            Defaults to None.

    Returns:
        torch.Tensor: B x M x C gathered features when ``mask`` is None,
        otherwise a K x C tensor holding the K selected rows.
    """
    channels = feat.size(2)
    # Repeat every index across the channel axis so that gather() pulls a
    # whole C-dim vector per index.
    index = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), channels)
    gathered = feat.gather(1, index)
    if mask is None:
        return gathered
    selector = mask.unsqueeze(2).expand_as(gathered)
    return gathered[selector].view(-1, channels)


def _transpose_and_gather_feat(feat, ind):
    """Gather per-location channel vectors from a B x C x H x W map.

    The map is rearranged to channels-last and its two spatial axes are
    flattened, so ``ind`` addresses positions in the flattened H * W axis.

    Args:
        feat (torch.Tensor): Feature map of shape B x C x H x W.
        ind (torch.Tensor): Integer indices into the flattened spatial
            axis, shape B x M.

    Returns:
        torch.Tensor: Gathered features of shape B x M x C.
    """
    # contiguous() is required before view() because permute() only
    # changes strides.
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    batch_size = channels_last.size(0)
    num_channels = channels_last.size(3)
    flattened = channels_last.view(batch_size, -1, num_channels)
    return _gather_feat(flattened, ind)


@MODELS.register_module()
class FastFocalLoss(nn.Module):
    """Reimplemented focal loss, exactly the same as the CornerNet version.

    Faster and costs much less memory: the positive term is evaluated only
    at the ``M`` gathered peak locations rather than over the whole heatmap.

    Args:
        focal_factor (int | float): Exponent of the modulating factor used
            in both the positive and the negative term. Defaults to 2.
    """

    def __init__(self, focal_factor=2):
        super(FastFocalLoss, self).__init__()
        self.focal_factor = focal_factor

    def forward(self, out, target, ind, mask, cat):
        """Compute the focal loss between a predicted and a target heatmap.

        Args:
            out (torch.Tensor): Predicted heatmap, B x C x H x W. Values
                are fed straight into ``log``, so they are presumably
                probabilities clamped strictly inside (0, 1) by the
                caller - TODO confirm.
            target (torch.Tensor): Ground-truth heatmap, B x C x H x W.
            ind (torch.Tensor): Integer indices of the peaks in the
                flattened H * W axis, B x M.
            mask (torch.Tensor): Marks which of the M slots hold a real
                peak (non-zero) and which are padding (zero), B x M.
            cat (torch.Tensor): Integer category id of each peak, B x M.

        Returns:
            torch.Tensor: Scalar loss. It is normalized by ``mask.sum()``
            unless that sum is zero, in which case only the negative term
            is returned.
        """
        mask = mask.float()

        # Negative term over the full heatmap; (1 - target)^4 shrinks the
        # penalty as the target value approaches 1.
        # NOTE(review): log(1 - out) is -inf wherever out == 1; this
        # relies on the caller keeping ``out`` below 1 - confirm.
        gt = torch.pow(1 - target, 4)
        neg_loss = torch.log(1 - out) * torch.pow(out, self.focal_factor) * gt
        neg_loss = neg_loss.sum()

        pos_pred_pix = _transpose_and_gather_feat(out, ind)  # B x M x C
        pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2))  # B x M x 1
        peak_mask = mask.unsqueeze(2)  # B x M x 1
        # Padded slots still gather some heatmap value. If that value is
        # exactly 0, log(0) * 0 evaluates to NaN and poisons the sum.
        # Setting the prediction to 1 there gives log(1) = 0, so padded
        # slots contribute exactly zero loss and zero gradient while valid
        # slots are left untouched.
        pos_pred = pos_pred.masked_fill(peak_mask == 0, 1.0)
        num_pos = mask.sum()
        pos_loss = torch.log(pos_pred) * torch.pow(
            1 - pos_pred, self.focal_factor) * peak_mask
        pos_loss = pos_loss.sum()
        if num_pos == 0:
            # No real peaks: avoid dividing by zero.
            return -neg_loss
        return -(pos_loss + neg_loss) / num_pos