cross_entropy_loss.py 2.99 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
import torch
Jiangmiao Pang's avatar
Jiangmiao Pang committed
2
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
3
import torch.nn.functional as F
Jiangmiao Pang's avatar
Jiangmiao Pang committed
4

Kai Chen's avatar
Kai Chen committed
5
from .utils import weight_reduce_loss, weighted_loss
Jiangmiao Pang's avatar
Jiangmiao Pang committed
6
7
from ..registry import LOSSES

Kai Chen's avatar
Kai Chen committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  **kwargs):
    """Softmax cross-entropy with optional per-sample weights.

    The element-wise loss is computed explicitly with ``reduction='none'`` so
    that ``weight`` and ``avg_factor`` are applied *before* the final
    reduction. ``F.cross_entropy`` defaults to ``reduction='mean'``; wrapping
    it generically risks receiving an already-averaged scalar, after which
    per-sample weights and ``avg_factor`` can no longer be applied correctly.

    Args:
        pred (Tensor): class logits, passed to ``F.cross_entropy`` as input.
        label (Tensor): target class indices.
        weight (Tensor, optional): element-wise loss weights; presumably one
            per sample (TODO confirm against callers).
        reduction (str): reduction mode forwarded to ``weight_reduce_loss``
            (the module's ``forward`` restricts overrides to
            ``'none'``/``'mean'``/``'sum'``).
        avg_factor (float, optional): forwarded to ``weight_reduce_loss``.
        **kwargs: extra keyword arguments for ``F.cross_entropy``
            (e.g. ``ignore_index``).

    Returns:
        Tensor: the reduced loss.
    """
    # element-wise losses; force 'none' so weighting happens per element
    loss = F.cross_entropy(pred, label, reduction='none', **kwargs)
    if weight is not None:
        loss = loss * weight.float()
    # do the reduction for the weighted loss
    return weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    """Sigmoid (binary) cross-entropy on logits with optional weights.

    When ``label`` and ``pred`` differ in rank, ``label`` is treated as class
    indices and expanded by ``_expand_binary_labels`` into binary targets with
    ``pred.size(-1)`` columns; ``weight`` is broadcast the same way.

    Args:
        pred (Tensor): logits.
        label (Tensor): binary targets with the same rank as ``pred``, or
            class indices to be expanded.
        weight (Tensor, optional): element-wise weights, passed to
            ``F.binary_cross_entropy_with_logits`` after ``.float()``.
        reduction (str): reduction mode forwarded to ``weight_reduce_loss``.
        avg_factor (float, optional): forwarded to ``weight_reduce_loss``.

    Returns:
        Tensor: the reduced loss.
    """
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')

    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
    """Binary cross-entropy on the mask channel selected by each RoI's label.

    For every row ``i`` of ``pred`` the slice ``pred[i, label[i]]`` is taken,
    then averaged BCE-with-logits against ``target`` is computed.

    Args:
        pred (Tensor): logits indexed as ``pred[roi, class, ...]``.
        target (Tensor): binary targets matching the selected slices.
        label (Tensor): per-RoI class index used to pick the channel.
        reduction (str): reserved; only ``'mean'`` is accepted.
        avg_factor: reserved; only ``None`` is accepted.

    Returns:
        Tensor: a 1-element tensor holding the mean loss.
    """
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    roi_inds = torch.arange(
        0, pred.size(0), dtype=torch.long, device=pred.device)
    # pick, for every RoI, the channel that belongs to its label
    selected = pred[roi_inds, label]
    selected = selected.squeeze(1)
    mean_loss = F.binary_cross_entropy_with_logits(
        selected, target, reduction='mean')
    return mean_loss[None]

Jiangmiao Pang's avatar
Jiangmiao Pang committed
52
53
54
55

@LOSSES.register_module
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss module with softmax, sigmoid and mask variants.

    The criterion is chosen once at construction time:

    * ``use_sigmoid`` truthy -> ``binary_cross_entropy``
    * ``use_mask`` truthy    -> ``mask_cross_entropy``
    * otherwise              -> ``cross_entropy``

    Args:
        use_sigmoid (bool): use the sigmoid / binary variant.
        use_mask (bool): use the mask variant; cannot be combined with
            ``use_sigmoid``.
        reduction (str): default reduction passed to the criterion; can be
            overridden per call in :meth:`forward`.
        loss_weight (float): scalar multiplier applied to the criterion output.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        # truthiness (not identity with False) so 0/None from configs are
        # accepted as "disabled"
        assert not (use_sigmoid and use_mask), \
            'use_sigmoid and use_mask cannot both be enabled'
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute ``loss_weight`` times the selected criterion.

        Args:
            cls_score (Tensor): predictions (logits) for the criterion.
            label (Tensor): targets for the criterion.
            weight (Tensor, optional): passed as the criterion's third
                positional argument. NOTE: for the mask variant this slot is
                ``mask_cross_entropy``'s ``label`` parameter (the per-RoI
                class index), not a loss weight.
            avg_factor (float, optional): forwarded to the criterion.
            reduction_override (str, optional): one of ``None``, ``'none'``,
                ``'mean'``, ``'sum'``; when truthy it replaces
                ``self.reduction`` for this call.
            **kwargs: forwarded to the criterion.

        Returns:
            Tensor: the scaled loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls