dice_loss.py 3.9 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F

from medicalseg.models.losses import flatten
from medicalseg.cvlibs import manager


@manager.LOSSES.add_component
class DiceLoss(nn.Layer):
    """
    Implements the Dice loss function for multi-class 3D segmentation.

    Args:
        sigmoid_norm (bool): If True, normalize logits with Sigmoid;
            otherwise apply Softmax over the channel axis (axis=1).
            Default: True.
        weight (Tensor|None): Optional per-channel/class weight tensor
            applied to the intersection term. Default: None.
    """

    def __init__(self, sigmoid_norm=True, weight=None):
        super(DiceLoss, self).__init__()
        self.weight = weight
        # NOTE(review): self.eps is currently unused — the denominator
        # clipping in compute_per_channel_dice uses its own epsilon=1e-6
        # default instead. Kept as-is to preserve behavior; consider
        # unifying the two constants.
        self.eps = 1e-5
        if sigmoid_norm:
            self.norm = nn.Sigmoid()
        else:
            self.norm = nn.Softmax(axis=1)

    def compute_per_channel_dice(self, input, target, epsilon=1e-6,
                                 weight=None):
        """
        Computes the per-channel Dice coefficient as defined in
        https://arxiv.org/abs/1606.04797 (V-Net) given a multi-channel
        input and target. Assumes ``input`` is a normalized probability,
        e.g. the result of a Sigmoid or Softmax function.

        Args:
            input (paddle.Tensor): NxCxSpatial input tensor.
            target (paddle.Tensor): NxCxSpatial target tensor.
            epsilon (float): Floor for the denominator; prevents division
                by zero.
            weight (paddle.Tensor|None): Per-channel weight applied to the
                intersection term.

        Returns:
            paddle.Tensor: Dice coefficient per channel, shape [C].
        """
        # input and target shapes must match
        assert input.shape == target.shape, \
            "'input' and 'target' must have the same shape but input is {} and target is {}".format(
                input.shape, target.shape)

        input = flatten(input)  # [C, N*D*H*W]
        target = flatten(target)
        target = paddle.cast(target, "float32")

        # Per-channel intersection, summed over the flattened spatial dim.
        intersect = (input * target).sum(-1)
        if weight is not None:
            intersect = weight * intersect  # give each class its own weight

        # V-Net variant denominator: sum of squares rather than the
        # standard Dice's plain sums (input + target).sum(-1).
        denominator = (input * input).sum(-1) + (target * target).sum(-1)

        return 2 * (intersect / paddle.clip(denominator, min=epsilon))

    def forward(self, logits, labels):
        """
        Args:
            logits (paddle.Tensor): Raw network output of shape
                [B, C, D, H, W] (or [C, D, H, W], which is auto-batched).
            labels (paddle.Tensor): Integer class labels of shape
                [B, D, H, W].

        Returns:
            tuple: (scalar Dice loss tensor,
                    per-channel Dice as a detached numpy array).
        """
        # Fix: the original used ``print(...)`` as the assert message —
        # print returns None, so the AssertionError carried no message.
        # Also report the offending dtype rather than type(labels).
        assert "int" in str(labels.dtype), \
            "The label should be int but got {}".format(labels.dtype)
        if len(logits.shape) == 4:
            # Accept unbatched [C, D, H, W] input by adding a batch axis.
            logits = logits.unsqueeze(0)

        labels_one_hot = F.one_hot(
            labels, num_classes=logits.shape[1])  # [B, D, H, W, C]
        labels_one_hot = paddle.transpose(labels_one_hot,
                                          [0, 4, 1, 2, 3])  # [B, C, D, H, W]

        labels_one_hot = paddle.cast(labels_one_hot, dtype='float32')

        logits = self.norm(logits)  # normalize logits to probabilities

        per_channel_dice = self.compute_per_channel_dice(
            logits, labels_one_hot, weight=self.weight)

        dice_loss = (1. - paddle.mean(per_channel_dice))
        # Detach so the returned metric carries no gradient graph.
        per_channel_dice = per_channel_dice.detach().cpu(
        ).numpy()  # vnet variant dice
        return dice_loss, per_channel_dice