import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES

# Element-wise criteria selectable via ``criterion_mode`` (function) or
# ``mode`` (ChamferDistance). Each is called with ``reduction='none'``.
_CRITERIA = {
    'smooth_l1': smooth_l1_loss,
    'l1': l1_loss,
    'l2': mse_loss,
}

_REDUCTIONS = ('none', 'sum', 'mean')


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    Args:
        src (tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (tensor or float): Weight of source loss.
        dst_weight (tensor or float): Weight of destination loss.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction methods are none, sum or mean.

    Returns:
        tuple: Source and Destination loss with indices.

            - loss_src (Tensor): The min distance from source to destination.
            - loss_dst (Tensor): The min distance from destination to source.
            - indices1 (Tensor): Index of the nearest destination point for
              each source point, shape (B, N).
            - indices2 (Tensor): Index of the nearest source point for
              each destination point, shape (B, M).

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    criterion = _CRITERIA.get(criterion_mode)
    if criterion is None:
        raise NotImplementedError(
            'Unsupported criterion_mode {!r}; expected one of {}.'.format(
                criterion_mode, list(_CRITERIA)))
    # Validate before the expensive pairwise computation below.
    if reduction not in _REDUCTIONS:
        raise NotImplementedError(
            'Unsupported reduction {!r}; expected one of {}.'.format(
                reduction, list(_REDUCTIONS)))

    num_src, num_dst = src.shape[1], dst.shape[1]
    # expand() returns broadcast views instead of copying the data like
    # repeat() does, saving two [B, N, M, C] allocations. The operand
    # shapes still match exactly, so the criterion sees the same values.
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)

    # Pairwise distance, summed over the feature dimension: (B, N, M).
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B, N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B, M)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses unreduced.

    return loss_src, loss_dst, indices1, indices2


@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction methods are none, sum or mean.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super(ChamferDistance, self).__init__()

        assert mode in _CRITERIA, 'invalid mode: {!r}'.format(mode)
        assert reduction in _REDUCTIONS, \
            'invalid reduction: {!r}'.format(reduction)
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Compute the weighted Chamfer Distance between two point sets.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C].
            target (torch.Tensor): Destination set with shape [B, M, C].
            src_weight (torch.Tensor | float, optional): Per-call weight of
                the source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional): Per-call weight of
                the destination loss. Defaults to 1.0.
            reduction_override (str, optional): If given (none, sum or
                mean), used instead of ``self.reduction``. Defaults to None.
            return_indices (bool, optional): Whether to also return the
                nearest-point indices. Defaults to False.
            **kwargs: Unused.

        Returns:
            tuple: ``(loss_source, loss_target)``, followed by
            ``(indices1, indices2)`` when ``return_indices`` is True.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Scale out of place rather than mutating the returned tensors.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target