chamfer_distance.py 5.76 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Optional, Tuple, Union

wuyuefeng's avatar
Votenet  
wuyuefeng committed
4
import torch
5
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
6
from torch import nn as nn
wuyuefeng's avatar
Votenet  
wuyuefeng committed
7
8
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

9
from mmdet3d.registry import MODELS
wuyuefeng's avatar
Votenet  
wuyuefeng committed
10
11


12
13
14
15
16
17
18
def chamfer_distance(
        src: Tensor,
        dst: Tensor,
        src_weight: Union[Tensor, float] = 1.0,
        dst_weight: Union[Tensor, float] = 1.0,
        criterion_mode: str = 'l2',
        reduction: str = 'mean') -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculate Chamfer Distance of two sets.

    For every point in ``src`` the closest point in ``dst`` is found (and
    vice versa). The pairwise distance is the chosen element-wise criterion
    summed over the last (channel) dimension.

    Args:
        src (Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (Tensor or float): Weight of source loss, multiplied
            (with broadcasting) into the per-point source distances of
            shape [B, N]. Defaults to 1.0.
        dst_weight (Tensor or float): Weight of destination loss,
            multiplied (with broadcasting) into the per-point destination
            distances of shape [B, M]. Defaults to 1.0.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (Tensor): The min distance
              from source to destination.
            - loss_dst (Tensor): The min distance
              from destination to source.
            - indices1 (Tensor): Index the min distance point
              for each point in source to destination, shape [B, N].
            - indices2 (Tensor): Index the min distance point
              for each point in destination to source, shape [B, M].

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    criteria = {
        'smooth_l1': smooth_l1_loss,
        'l1': l1_loss,
        'l2': mse_loss,
    }
    if criterion_mode not in criteria:
        raise NotImplementedError(
            f'Unsupported criterion_mode {criterion_mode!r}; '
            f'expected one of {sorted(criteria)}.')
    # Validate before the O(B*N*M*C) distance computation so a bad argument
    # fails fast instead of after the expensive part.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(
            f'Unsupported reduction {reduction!r}; '
            "expected one of ['mean', 'none', 'sum'].")
    criterion = criteria[criterion_mode]

    num_src, num_dst = src.shape[1], dst.shape[1]
    # ``expand`` yields broadcast views rather than materialising two full
    # [B, N, M, C] copies as ``repeat`` would; values and gradients match.
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)

    # Pairwise distances [B, N, M]: element-wise criterion summed over C.
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B, N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B, M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses as they are.

    return loss_src, loss_dst, indices1, indices2


81
@MODELS.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    ``nn.Module`` wrapper around :func:`chamfer_distance` that additionally
    scales the source and target losses by fixed per-direction weights.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_src_weight (float): Weight of loss_source. Defaults to 1.0.
        loss_dst_weight (float): Weight of loss_target. Defaults to 1.0.
    """

    def __init__(self,
                 mode: str = 'l2',
                 reduction: str = 'mean',
                 loss_src_weight: float = 1.0,
                 loss_dst_weight: float = 1.0) -> None:
        super().__init__()

        assert mode in ['smooth_l1', 'l1', 'l2'], (
            f'Invalid mode {mode!r}; expected smooth_l1, l1 or l2.')
        assert reduction in ['none', 'sum', 'mean'], (
            f'Invalid reduction {reduction!r}; expected none, sum or mean.')
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(
        self,
        source: Tensor,
        target: Tensor,
        src_weight: Union[Tensor, float] = 1.0,
        dst_weight: Union[Tensor, float] = 1.0,
        reduction_override: Optional[str] = None,
        return_indices: bool = False,
        **kwargs
    ) -> Union[Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        """Forward function of loss calculation.

        Args:
            source (Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (Tensor | float):
                Weight of source loss. Defaults to 1.0.
            dst_weight (Tensor | float):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                If None, ``self.reduction`` is used. Defaults to None.
            return_indices (bool): Whether to return indices.
                Defaults to False.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            tuple[Tensor]: If ``return_indices=True``, return losses of
                source and target with their corresponding indices in the
                order of ``(loss_source, loss_target, indices1, indices2)``.
                If ``return_indices=False``, return
                ``(loss_source, loss_target)``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), (
            f'Invalid reduction_override {reduction_override!r}.')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Scale out-of-place rather than with ``*=`` so tensors produced
        # inside the autograd graph are never mutated in place.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target