chamfer_distance.py 5.41 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
wuyuefeng's avatar
Votenet  
wuyuefeng committed
2
import torch
zhangwenwei's avatar
zhangwenwei committed
3
from torch import nn as nn
wuyuefeng's avatar
Votenet  
wuyuefeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss.
            Must be broadcastable to [B, N].
        dst_weight (torch.Tensor or float): Weight of destination loss.
            Must be broadcastable to [B, M].
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. With 'l2' the
            per-pair distance is the squared Euclidean distance.
        reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (torch.Tensor): The min distance
              from source to destination.
            - loss_dst (torch.Tensor): The min distance
              from destination to source.
            - indices1 (torch.Tensor): Index of the min distance point
              in destination for each point in source, shape [B, N].
            - indices2 (torch.Tensor): Index of the min distance point
              in source for each point in destination, shape [B, M].

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    criteria = {
        'smooth_l1': smooth_l1_loss,
        'l1': l1_loss,
        'l2': mse_loss,
    }
    if criterion_mode not in criteria:
        raise NotImplementedError(
            f'Unsupported criterion_mode: {criterion_mode!r}')
    # Validate before the O(B*N*M*C) pairwise computation below.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(f'Unsupported reduction: {reduction!r}')
    criterion = criteria[criterion_mode]

    num_src, num_dst = src.shape[1], dst.shape[1]
    # expand() returns broadcast views instead of the full copies that
    # repeat() allocates; values and gradients are identical.
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)  # (B,N,M,C)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)  # (B,N,M,C)

    # Element-wise criterion summed over the coordinate dim -> (B,N,M).
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep per-point losses.

    return loss_src, loss_dst, indices1, indices2


@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super(ChamferDistance, self).__init__()

        assert mode in ['smooth_l1', 'l1', 'l2'], \
            f'Unsupported mode: {mode!r}'
        assert reduction in ['none', 'sum', 'mean'], \
            f'Unsupported reduction: {reduction!r}'
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None, which uses ``self.reduction``.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.
            **kwargs: Unused by this loss.

        Returns:
            tuple[torch.Tensor]: If ``return_indices=True``, return losses of
                source and target with their corresponding indices in the
                order of ``(loss_source, loss_target, indices1, indices2)``.
                If ``return_indices=False``, return
                ``(loss_source, loss_target)``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Scale out of place rather than mutating the returned tensors.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target