chamfer_distance.py 5.31 KB
Newer Older
wuyuefeng's avatar
Votenet  
wuyuefeng committed
1
import torch
zhangwenwei's avatar
zhangwenwei committed
2
from torch import nn as nn
wuyuefeng's avatar
Votenet  
wuyuefeng committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    The distance between every source point and every destination point is
    computed with the selected element-wise criterion and summed over the
    channel (last) dimension. Each point is then matched with its nearest
    point in the other set.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss. It is
            multiplied with the per-point source distances of shape [B, N]
            before reduction, so a tensor weight must broadcast to [B, N].
        dst_weight (torch.Tensor or float): Weight of destination loss,
            multiplied with the per-point destination distances of shape
            [B, M] before reduction.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Note that 'l2'
            uses ``mse_loss`` and therefore yields the *squared* Euclidean
            distance (no square root is taken).
        reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.

    Returns:
        tuple: Source and Destination loss with indices.

            - loss_src (torch.Tensor): The min distance
                from source to destination ([B, N] if reduction is 'none',
                otherwise a scalar).
            - loss_dst (torch.Tensor): The min distance
                from destination to source ([B, M] if reduction is 'none',
                otherwise a scalar).
            - indices1 (torch.Tensor): Index (into ``dst``) of the min
                distance point for each point in source, shape [B, N].
            - indices2 (torch.Tensor): Index (into ``src``) of the min
                distance point for each point in destination, shape [B, M].

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError(
            f'Unsupported criterion_mode {criterion_mode!r}; expected one '
            "of 'smooth_l1', 'l1' or 'l2'.")

    # Validate up front so a bad argument fails before the (potentially
    # large) pairwise computation rather than after it.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(
            f'Unsupported reduction {reduction!r}; expected one of '
            "'none', 'sum' or 'mean'.")

    num_src, num_dst = src.shape[1], dst.shape[1]
    # expand() returns broadcast views instead of copies (as repeat() did),
    # so only the criterion output of size B*N*M*C is materialised. The
    # values seen by the criterion and the resulting gradients are identical.
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)  # (B,N,M,C)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)  # (B,N,M,C)

    # Pairwise distance matrix, summed over channels: (B,N,M).
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses unreduced.

    return loss_src, loss_dst, indices1, indices2


@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction methods are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_src_weight (float): Weight of the source loss. Defaults to 1.0.
        loss_dst_weight (float): Weight of the destination (target) loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super().__init__()

        assert mode in ['smooth_l1', 'l1', 'l2'], \
            f'Invalid mode {mode!r}; expected smooth_l1, l1 or l2.'
        assert reduction in ['none', 'sum', 'mean'], \
            f'Invalid reduction {reduction!r}; expected none, sum or mean.'
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                When None, ``self.reduction`` is used. Defaults to None.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            tuple[torch.Tensor]: If ``return_indices=True``, return losses of
                source and target with their corresponding indices in the order
                of (loss_source, loss_target, indices1, indices2). If
                ``return_indices=False``, return (loss_source, loss_target).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'Invalid reduction_override {reduction_override!r}.'
        reduction = (
            reduction_override
            if reduction_override is not None else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Out-of-place scaling: avoids mutating tensors that belong to the
        # autograd graph built inside chamfer_distance.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target