import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from mmcv.runner import BaseModule, force_fp32
from torch.cuda.amp import autocast

# Per-class voxel counts (index 0 = "empty"); presumably gathered from the
# SemanticKITTI training split and used to derive class-balancing weights.
semantic_kitti_class_frequencies = np.array(
    [
        5.41773033e09,
        1.57835390e07,
        1.25136000e05,
        1.18809000e05,
        6.46799000e05,
        8.21951000e05,
        2.62978000e05,
        2.83696000e05,
        2.04750000e05,
        6.16887030e07,
        4.50296100e06,
        4.48836500e07,
        2.26992300e06,
        5.68402180e07,
        1.57196520e07,
        1.58442623e08,
        2.06162300e06,
        3.69705220e07,
        1.15198800e06,
        3.34146000e05,
    ]
)

# Class names aligned index-for-index with `semantic_kitti_class_frequencies`.
kitti_class_names = [
    "empty",
    "car",
    "bicycle",
    "motorcycle",
    "truck",
    "other-vehicle",
    "person",
    "bicyclist",
    "motorcyclist",
    "road",
    "parking",
    "sidewalk",
    "other-ground",
    "building",
    "fence",
    "vegetation",
    "trunk",
    "terrain",
    "pole",
    "traffic-sign",
]


def inverse_sigmoid(x, sign='A'):
    """Numerically safe logit transform.

    Args:
        x: tensor of probabilities (any shape).
        sign: unused debug tag, kept for backward compatibility with callers.

    Returns:
        float32 tensor ``log(x / (1 - x))``, with ``x`` clamped to
        ``[1e-5, 1 - 1e-5]`` so the result is always finite.
    """
    x = x.to(torch.float32)
    # torch.clamp replaces the original iterative while-loop nudging, which
    # was slow and raised "Boolean value of Tensor is ambiguous" for any
    # non-scalar input.
    x = torch.clamp(x, min=1e-5, max=1.0 - 1e-5)
    return -torch.log((1 / x) - 1)


def KL_sep(p, target):
    """KL divergence restricted to the classes with nonzero target mass.

    Args:
        p: predicted probability tensor.
        target: target probability tensor, same shape as ``p``.

    Returns:
        Summed ``KL(target || p)`` over entries where ``target != 0``.
    """
    nonzeros = target != 0
    nonzero_p = p[nonzeros]
    return F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum")


def geo_scal_loss(pred, ssc_target, ignore_index=255, non_empty_idx=0):
    """Geometric scene-completion affinity loss (precision/recall/specificity
    of the binary empty-vs-occupied decision).

    Args:
        pred: raw logits of shape ``(B, n_class, ...)``.
        ssc_target: integer label volume of shape ``(B, ...)``.
        ignore_index: label value excluded from the loss.
        non_empty_idx: channel index of the "empty" class.

    Returns:
        Scalar tensor: sum of three BCE terms pushing precision, recall and
        specificity of the occupancy prediction toward 1.
    """
    # Softmax over the class channel, then collapse to a binary occupancy
    # probability per voxel.
    pred = F.softmax(pred, dim=1)
    empty_probs = pred[:, non_empty_idx]
    nonempty_probs = 1 - empty_probs

    # Drop voxels with the ignore label before computing the statistics.
    mask = ssc_target != ignore_index
    nonempty_target = (ssc_target != non_empty_idx)[mask].float()
    nonempty_probs = nonempty_probs[mask]
    empty_probs = empty_probs[mask]

    eps = 1e-5
    intersection = (nonempty_target * nonempty_probs).sum()
    precision = intersection / (nonempty_probs.sum() + eps)
    recall = intersection / (nonempty_target.sum() + eps)
    spec = ((1 - nonempty_target) * empty_probs).sum() / (
        (1 - nonempty_target).sum() + eps
    )

    # BCE-with-logits on the inverse-sigmoid of each ratio is a numerically
    # stable form of -log(ratio); computed in fp32 (autocast disabled).
    with autocast(False):
        return (
            F.binary_cross_entropy_with_logits(
                inverse_sigmoid(precision, 'A'), torch.ones_like(precision)
            )
            + F.binary_cross_entropy_with_logits(
                inverse_sigmoid(recall, 'B'), torch.ones_like(recall)
            )
            + F.binary_cross_entropy_with_logits(
                inverse_sigmoid(spec, 'C'), torch.ones_like(spec)
            )
        )


def sem_scal_loss(pred_, ssc_target, ignore_index=255):
    """Semantic scene-completion affinity loss.

    For every semantic class (except the last channel) that is present in the
    target, accumulates BCE terms for per-class precision, recall and
    specificity, then averages over the number of present classes.

    Args:
        pred_: raw logits of shape ``(B, n_class, Dx, Dy, Dz)``.
        ssc_target: integer label volume of shape ``(B, Dx, Dy, Dz)``.
        ignore_index: label value excluded from the loss.

    Returns:
        Scalar tensor; zero when no eligible class appears in the target.

    Raises:
        FloatingPointError: if the averaged loss is NaN (the original code
        dropped into an interactive IPython shell here).
    """
    with autocast(False):
        pred = F.softmax(pred_, dim=1)  # (B, n_class, Dx, Dy, Dz)
        loss = 0
        count = 0
        mask = ssc_target != ignore_index
        n_classes = pred.shape[1]
        for i in range(0, n_classes - 1):
            # Probability of class i on the non-ignored voxels.
            p = pred[:, i][mask]
            target = ssc_target[mask]

            # Binary ground truth: 1 where the voxel belongs to class i.
            completion_target = torch.ones_like(target)
            completion_target[target != i] = 0

            if torch.sum(completion_target) > 0:
                count += 1.0
                nominator = torch.sum(p * completion_target)
                loss_class = 0
                if torch.sum(p) > 0:
                    precision = nominator / (torch.sum(p) + 1e-5)
                    loss_class += F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(precision, 'D'),
                        torch.ones_like(precision),
                    )
                if torch.sum(completion_target) > 0:
                    recall = nominator / (torch.sum(completion_target) + 1e-5)
                    loss_class += F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(recall, 'E'), torch.ones_like(recall)
                    )
                if torch.sum(1 - completion_target) > 0:
                    specificity = torch.sum((1 - p) * (1 - completion_target)) / (
                        torch.sum(1 - completion_target) + 1e-5
                    )
                    loss_class += F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(specificity, 'F'),
                        torch.ones_like(specificity),
                    )
                loss += loss_class

        # Guard the division: with no eligible class the original raised
        # ZeroDivisionError.
        if count == 0:
            return pred_.new_zeros(())

        mean_loss = loss / count
        if torch.isnan(mean_loss):
            # Fail loudly instead of opening an interactive debugger.
            raise FloatingPointError("sem_scal_loss produced NaN")
        return mean_loss


def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
    """Voxel-wise cross-entropy loss for semantic scene completion.

    Args:
        pred: the predicted tensor, must be ``[BS, C, ...]``.
        target: integer label tensor of shape ``[BS, ...]``.
        class_weights: optional per-class weight tensor of length ``C``.
        ignore_index: label value excluded from the loss.

    Returns:
        Scalar mean cross-entropy, computed in fp32 (autocast disabled).
    """
    criterion = nn.CrossEntropyLoss(
        weight=class_weights, ignore_index=ignore_index, reduction="mean"
    )
    with autocast(False):
        return criterion(pred, target.long())


def vel_loss(pred, gt):
    """Mean absolute (L1) error between predicted and ground-truth values,
    computed in fp32 (autocast disabled)."""
    with autocast(False):
        return F.l1_loss(pred, gt)