losses.py 4.87 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
# TODO merge naive and weighted loss.
pangjm's avatar
pangjm committed
2
3
4
import torch
import torch.nn.functional as F

5
from ..bbox import bbox_overlaps
Cao Yuhang's avatar
Cao Yuhang committed
6
7
from ...ops import sigmoid_focal_loss

pangjm's avatar
pangjm committed
8

Kai Chen's avatar
Kai Chen committed
9
10
11
12
13
def weighted_nll_loss(pred, label, weight, avg_factor=None):
    """Negative log-likelihood loss with per-sample weights.

    Args:
        pred: Input forwarded to ``F.nll_loss`` (log-probabilities).
        label: Targets forwarded to ``F.nll_loss``.
        weight: Weights multiplied element-wise into the unreduced loss.
        avg_factor: Divisor for the weighted sum. When ``None`` it is the
            number of strictly positive entries in ``weight``, floored
            at 1 so the division is always defined.

    Returns:
        Tensor: the normalised weighted loss as a 1-element tensor.
    """
    if avg_factor is None:
        num_positive = torch.sum(weight > 0).float().item()
        avg_factor = max(num_positive, 1.)
    per_sample = F.nll_loss(pred, label, reduction='none')
    weighted_total = torch.sum(per_sample * weight)
    # Indexing with ``None`` lifts the 0-dim sum to shape (1,).
    return weighted_total[None] / avg_factor
pangjm's avatar
pangjm committed
14
15


16
def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
    """Cross-entropy loss scaled by per-sample weights and ``avg_factor``.

    Args:
        pred: Class scores forwarded to ``F.cross_entropy``.
        label: Targets forwarded to ``F.cross_entropy``.
        weight: Weights multiplied element-wise into the unreduced loss.
        avg_factor: Divisor applied to the weighted loss. When ``None`` it
            is the number of strictly positive entries in ``weight``,
            floored at 1 so the division is always defined.
        reduce: If true, sum the weighted loss into a 1-element tensor;
            otherwise return the weighted loss without summing.

    Returns:
        Tensor: the weighted loss divided by ``avg_factor``; shape ``(1,)``
        when ``reduce`` is true, otherwise the shape of the unreduced loss.
    """
    if avg_factor is None:
        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
    raw = F.cross_entropy(pred, label, reduction='none')
    if reduce:
        # ``[None]`` lifts the 0-dim sum to shape (1,).
        return torch.sum(raw * weight)[None] / avg_factor
    else:
        return raw * weight / avg_factor
pangjm's avatar
pangjm committed
24
25


Kai Chen's avatar
Kai Chen committed
26
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
    """Weighted sigmoid binary cross-entropy, normalised by ``avg_factor``.

    Args:
        pred: Logits forwarded to ``F.binary_cross_entropy_with_logits``.
        label: Targets. If its number of dimensions differs from ``pred``
            it is expanded (together with ``weight``) by
            ``_expand_binary_labels`` to ``pred.size(-1)`` channels.
        weight: Element-wise weights handed to the BCE call.
        avg_factor: Divisor for the summed loss. When ``None`` it is the
            number of strictly positive entries in the (possibly expanded)
            ``weight``, floored at 1.

    Returns:
        Tensor: the normalised summed loss as a 1-element tensor.
    """
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))
    # Computed after the expansion so it counts the expanded weights.
    if avg_factor is None:
        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
    return F.binary_cross_entropy_with_logits(
        pred, label.float(), weight.float(),
        reduction='sum')[None] / avg_factor
pangjm's avatar
pangjm committed
34
35


Cao Yuhang's avatar
Cao Yuhang committed
36
37
38
39
40
41
def py_sigmoid_focal_loss(pred,
                          target,
                          weight,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean'):
    """Sigmoid focal loss written with plain PyTorch operations.

    Args:
        pred: Logits; ``pred.sigmoid()`` is used as the predicted
            probability.
        target: Targets, cast to the type of ``pred``.
        weight: Extra weights multiplied into the focal weighting.
        gamma: Exponent applied to the modulating term ``pt``.
        alpha: Balance factor; elements with target 1 are scaled by
            ``alpha`` and elements with target 0 by ``1 - alpha``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, interpreted by
            ``F._Reduction.get_enum``.

    Returns:
        Tensor: the element-wise loss for ``'none'``, otherwise its mean
        or sum.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # For 0/1 targets, ``pt`` is the probability given to the wrong class.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
pangjm's avatar
pangjm committed
57
58
59
60
61
62
63


def weighted_sigmoid_focal_loss(pred,
                                target,
                                weight,
                                gamma=2.0,
                                alpha=0.25,
                                avg_factor=None,
                                num_classes=80):
    """Weighted focal loss built on the imported ``sigmoid_focal_loss`` op.

    Args:
        pred: Predictions forwarded to ``sigmoid_focal_loss``.
        target: Targets forwarded to ``sigmoid_focal_loss``.
        weight: Weights; reshaped to a column with ``view(-1, 1)`` before
            being multiplied into the unreduced loss.
        gamma: Forwarded to ``sigmoid_focal_loss``.
        alpha: Forwarded to ``sigmoid_focal_loss``.
        avg_factor: Divisor for the weighted sum. When ``None`` it is the
            number of strictly positive weights divided by
            ``num_classes``, plus ``1e-6`` to avoid division by zero.
        num_classes: Only used to derive the default ``avg_factor``.

    Returns:
        Tensor: the normalised weighted loss as a 1-element tensor.
    """
    if avg_factor is None:
        avg_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
    return torch.sum(
        sigmoid_focal_loss(pred, target, gamma, alpha, 'none') * weight.view(
            -1, 1))[None] / avg_factor
pangjm's avatar
pangjm committed
71
72
73
74
75
76
77


def mask_cross_entropy(pred, target, label):
    """Binary cross-entropy on the prediction channel selected by ``label``.

    Args:
        pred: Logits whose first dimension indexes RoIs and whose second
            dimension is indexed by ``label``.
        target: Targets compared against the selected slice of ``pred``.
        label: Index into dimension 1 of ``pred``, one per RoI.

    Returns:
        Tensor: the mean BCE-with-logits loss as a 1-element tensor.
    """
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # Pairing ``inds`` with ``label`` picks entry ``label[i]`` along dim 1
    # for RoI ``i``; ``squeeze(1)`` then drops dim 1 only if it has size 1.
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, reduction='mean')[None]
pangjm's avatar
pangjm committed
79
80


Kai Chen's avatar
Kai Chen committed
81
def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
    """Smooth L1 loss with a configurable transition point ``beta``.

    Element-wise, with ``d = |pred - target|``, the loss is
    ``0.5 * d**2 / beta`` where ``d < beta`` and ``d - 0.5 * beta``
    elsewhere.

    Args:
        pred: Predictions; must have the same size as ``target``.
        target: Targets; must be non-empty.
        beta: Positive threshold between the quadratic and linear parts.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, interpreted by
            ``F._Reduction.get_enum``.

    Returns:
        Tensor: the element-wise loss for ``'none'``, otherwise its mean
        over all elements or its sum.

    Raises:
        AssertionError: if ``beta <= 0``, the sizes differ, or ``target``
            is empty.
    """
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.sum() / pred.numel()
    elif reduction_enum == 2:
        return loss.sum()


def weighted_smoothl1(pred, target, weight, beta=1.0, avg_factor=None):
    """Smooth L1 loss with element-wise weights, normalised by ``avg_factor``.

    Args:
        pred: Predictions forwarded to ``smooth_l1_loss``.
        target: Targets forwarded to ``smooth_l1_loss``.
        weight: Weights multiplied into the unreduced loss.
        beta: Forwarded to ``smooth_l1_loss``.
        avg_factor: Divisor for the weighted sum. When ``None`` it is the
            number of strictly positive weights divided by 4, plus
            ``1e-6`` to avoid division by zero.

    Returns:
        Tensor: the normalised weighted loss as a 1-element tensor.
    """
    if avg_factor is None:
        positive_count = torch.sum(weight > 0).float().item()
        avg_factor = positive_count / 4 + 1e-6
    elementwise = smooth_l1_loss(pred, target, beta, reduction='none')
    weighted_total = torch.sum(elementwise * weight)
    # Indexing with ``None`` lifts the 0-dim sum to shape (1,).
    return weighted_total[None] / avg_factor
pangjm's avatar
pangjm committed
102
103
104
105
106
107


def accuracy(pred, target, topk=1):
    """Top-k accuracy, in percent, of ``pred`` against ``target``.

    Args:
        pred: Scores with samples along dim 0 and classes along dim 1
            (``topk`` is taken over dim 1).
        target: Ground-truth class index per sample.
        topk: Either a single ``int`` or an iterable of ``k`` values.

    Returns:
        A 1-element tensor when ``topk`` is an ``int``; otherwise a list
        with one 1-element tensor per requested ``k``.
    """
    if isinstance(topk, int):
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    _, pred_label = pred.topk(maxk, 1, True, True)
    # Transposed so that row ``i`` holds the rank-``i`` prediction of every
    # sample.
    pred_label = pred_label.t()
    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))

    res = []
    for k in topk:
        # ``reshape`` rather than ``view``: ``correct`` derives from a
        # transposed tensor and may be non-contiguous, in which case
        # ``view(-1)`` raises a RuntimeError for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / pred.size(0)))
    return res[0] if return_single else res
121
122
123
124
125
126
127
128
129
130


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
131
132
133
134
135
136
137
138
139
140
141
142
143


def iou_loss(pred_bboxes, target_bboxes, reduction='mean', eps=1e-6):
    """IoU loss, ``-log(IoU)``, between aligned box pairs.

    Args:
        pred_bboxes: Predicted boxes forwarded to ``bbox_overlaps``.
        target_bboxes: Target boxes forwarded to ``bbox_overlaps``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, interpreted by
            ``F._Reduction.get_enum``.
        eps: Lower bound applied to the overlaps before the logarithm.

    Returns:
        Tensor: the element-wise loss for ``'none'``, otherwise its mean
        or sum.
    """
    ious = bbox_overlaps(pred_bboxes, target_bboxes, is_aligned=True)
    # An overlap of exactly 0 would give log(0) = -inf, i.e. an infinite
    # loss; clamping keeps the loss finite. Assumes ``bbox_overlaps`` can
    # return 0 for disjoint boxes -- TODO confirm against its definition.
    loss = -ious.clamp(min=eps).log()

    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()