cross_entropy_loss.py 3 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
import torch
Jiangmiao Pang's avatar
Jiangmiao Pang committed
2
import torch.nn as nn
Kai Chen's avatar
Kai Chen committed
3
import torch.nn.functional as F
Jiangmiao Pang's avatar
Jiangmiao Pang committed
4

Kai Chen's avatar
Kai Chen committed
5
from .utils import weight_reduce_loss, weighted_loss
Jiangmiao Pang's avatar
Jiangmiao Pang committed
6
7
from ..registry import LOSSES

Kai Chen's avatar
Kai Chen committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Multi-class cross entropy: ``F.cross_entropy`` wrapped by the project's
# ``weighted_loss`` helper. ``CrossEntropyLoss.forward`` calls the result as
# ``cross_entropy(cls_score, label, weight, reduction=..., avg_factor=...)``.
# NOTE(review): presumably ``weighted_loss`` is what adds the ``weight`` /
# ``reduction`` / ``avg_factor`` handling -- confirm against ``.utils``.
cross_entropy = weighted_loss(F.cross_entropy)


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    """Sigmoid (binary) cross entropy computed on logits.

    Args:
        pred (Tensor): Prediction logits.
        label (Tensor): Targets. When its number of dimensions differs from
            ``pred``, it is treated as integer class labels and expanded to
            binary targets with ``pred.size(-1)`` channels.
        weight (Tensor | None): Loss weights. Expanded together with
            ``label`` when the expansion above happens. Cast to float.
        reduction (str): Reduction mode forwarded to ``weight_reduce_loss``.
        avg_factor (float | None): Averaging factor forwarded to
            ``weight_reduce_loss``.

    Returns:
        Tensor: The value returned by ``weight_reduce_loss``.
    """
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    if weight is not None:
        weight = weight.float()

    # Keep the element-wise losses unweighted. ``weight`` is applied by
    # ``weight_reduce_loss`` below; also passing it to
    # ``binary_cross_entropy_with_logits`` (as was done previously) scales
    # every element by ``weight ** 2`` instead of ``weight``.
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), reduction='none')

    # apply weights (once) and do the reduction
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
    """Binary cross entropy on the per-RoI prediction slice chosen by label.

    For every index ``i`` along the first dimension of ``pred``, the slice
    ``pred[i, label[i]]`` is selected and compared against ``target`` with
    ``F.binary_cross_entropy_with_logits`` (mean reduction).

    Args:
        pred (Tensor): Prediction logits, indexed as ``pred[roi, class]``.
        target (Tensor): Targets matching the selected slices.
        label (Tensor): Index into the second dimension of ``pred`` for
            each RoI.
        reduction (str): Reserved; must be ``'mean'``.
        avg_factor (None): Reserved; must be ``None``.

    Returns:
        Tensor: The mean loss with a leading dimension of size 1.
    """
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    roi_ids = torch.arange(
        0, pred.size(0), dtype=torch.long, device=pred.device)
    # Pick, for each RoI, the channel given by its label.
    selected = pred[roi_ids, label]
    selected = selected.squeeze(1)

    mean_loss = F.binary_cross_entropy_with_logits(
        selected, target, reduction='mean')
    # Indexing with None turns the scalar into a 1-element tensor.
    return mean_loss[None]

Jiangmiao Pang's avatar
Jiangmiao Pang committed
53
54
55
56

@LOSSES.register_module
class CrossEntropyLoss(nn.Module):
    """Configurable cross entropy loss module.

    Depending on the flags, one of three criteria is used:
    ``binary_cross_entropy`` (``use_sigmoid``), ``mask_cross_entropy``
    (``use_mask``) or ``cross_entropy`` (neither flag set).

    Args:
        use_sigmoid (bool): Select the sigmoid / binary criterion.
        use_mask (bool): Select the mask criterion. At least one of
            ``use_sigmoid`` and ``use_mask`` must be exactly ``False``.
        reduction (str): Default reduction passed to the criterion.
        loss_weight (float): Scalar multiplied onto the criterion output.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        # The two modes are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight

        # Resolve the criterion once at construction time.
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        else:
            criterion = (
                mask_cross_entropy if self.use_mask else cross_entropy)
        self.cls_criterion = criterion

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted loss.

        Args:
            cls_score (Tensor): Predictions, passed to the criterion first.
            label (Tensor): Targets, passed to the criterion second.
            weight (Tensor | None): Passed to the criterion third.
            avg_factor (float | None): Forwarded as ``avg_factor``.
            reduction_override (str | None): One of ``None``, ``'none'``,
                ``'mean'`` or ``'sum'``; when set it replaces the reduction
                chosen at construction.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            Tensor: ``loss_weight`` times the criterion output.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        loss_cls = self.loss_weight * raw_loss
        return loss_cls