# chamfer_distance.py
import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    For every point in ``src`` the distance to its nearest point in ``dst``
    is found, and vice versa. The distance between two points is the chosen
    element-wise criterion summed over the last (channel) dimension, so
    ``criterion_mode='l2'`` yields the *squared* Euclidean distance.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss.
        dst_weight (torch.Tensor or float): Weight of destination loss.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (torch.Tensor): The min distance
              from source to destination.
            - loss_dst (torch.Tensor): The min distance
              from destination to source.
            - indices1 (torch.Tensor): Index of the min distance point
              in destination for each point in source, shape [B, N].
            - indices2 (torch.Tensor): Index of the min distance point
              in source for each point in destination, shape [B, M].

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError(
            f'Unsupported criterion_mode {criterion_mode!r}; expected '
            "'smooth_l1', 'l1' or 'l2'.")

    # Validate before doing the (B x N x M x C) pairwise work below.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(
            f'Unsupported reduction {reduction!r}; expected '
            "'none', 'sum' or 'mean'.")

    num_src, num_dst = src.shape[1], dst.shape[1]
    # expand() returns broadcast views ([B, N, M, C]) instead of copying the
    # data as repeat() would. Both operands have the same shape, so the
    # criterion does not warn about implicit broadcasting. Gradients are
    # identical: the backward of expand sums over the expanded dimension.
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)

    # Pairwise distance [B, N, M]: element-wise criterion summed over C.
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses unchanged.

    return loss_src, loss_dst, indices1, indices2


@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.
        loss_src_weight (float): Weight applied to the source loss.
        loss_dst_weight (float): Weight applied to the destination
            (target) loss.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super(ChamferDistance, self).__init__()

        assert mode in ['smooth_l1', 'l1', 'l2'], \
            f'Invalid mode {mode!r}; expected smooth_l1, l1 or l2.'
        assert reduction in ['none', 'sum', 'mean'], \
            f'Invalid reduction {reduction!r}; expected none, sum or mean.'
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                Defaults to None, which uses ``self.reduction``.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.
            **kwargs: Accepted and ignored.

        Returns:
            tuple[torch.Tensor]: If ``return_indices=True``, return losses of
                source and target with their corresponding indices in the
                order of ``(loss_source, loss_target, indices1, indices2)``.
                If ``return_indices=False``, return
                ``(loss_source, loss_target)``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'Invalid reduction_override {reduction_override!r}.'
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Apply the module-level weights out of place so the tensors produced
        # by chamfer_distance are never mutated in place.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target