"website/vscode:/vscode.git/clone" did not exist on "51761b3af172b4fc54ce0a3abc302e203d2bf44a"
Commit 4895cd74 authored by jihanyang's avatar jihanyang
Browse files

add training losses

parent 635f1e94
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import box_utils
class SigmoidFocalClassificationLoss(nn.Module):
"""
Sigmoid focal cross entropy loss.
"""
def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
"""
Args:
gamma: Weighting parameter to balance loss for hard and easy examples.
alpha: Weighting parameter to balance loss for positive and negative examples.
"""
super(SigmoidFocalClassificationLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
@staticmethod
def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
""" PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits:
max(x, 0) - x * z + log(1 + exp(-abs(x))) in
https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
Args:
input: (B, #anchors, #classes) float tensor.
Predicted logits for each class
target: (B, #anchors, #classes) float tensor.
One-hot encoded classification targets
Returns:
loss: (B, #anchors, #classes) float tensor.
Sigmoid cross entropy loss without reduction
"""
loss = torch.clamp(input, min=0) - input * target + \
torch.log1p(torch.exp(-torch.abs(input)))
return loss
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
"""
Args:
input: (B, #anchors, #classes) float tensor.
Predicted logits for each class
target: (B, #anchors, #classes) float tensor.
One-hot encoded classification targets
weights: (B, #anchors) float tensor.
Anchor-wise weights.
Returns:
weighted_loss: (B, #anchors, #classes) float tensor after weighting.
"""
pred_sigmoid = torch.sigmoid(input)
alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha)
pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid
focal_weight = alpha_weight * torch.pow(pt, self.gamma)
bce_loss = self.sigmoid_cross_entropy_with_logits(input, target)
loss = focal_weight * bce_loss
if weights.shape.__len__() == 2 or \
(weights.shape.__len__() == 1 and target.shape.__len__() == 2):
weights = weights.unsqueeze(-1)
assert weights.shape.__len__() == loss.shape.__len__()
return loss * weights
class WeightedSmoothL1Loss(nn.Module):
"""
Code-wise Weighted Smooth L1 Loss modified based on fvcore.nn.smooth_l1_loss
https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/smooth_l1_loss.py
| 0.5 * x ** 2 / beta if abs(x) < beta
smoothl1(x) = |
| abs(x) - 0.5 * beta otherwise,
where x = input - target.
"""
def __init__(self, beta: float = 1.0 / 9.0, code_weights: list = None):
"""
Args:
beta: Scalar float.
L1 to L2 change point.
For beta values < 1e-5, L1 loss is computed.
code_weights: (#codes) float list if not None.
Code-wise weights.
"""
super(WeightedSmoothL1Loss, self).__init__()
self.beta = beta
if code_weights is not None:
self.code_weights = np.array(code_weights, dtype=np.float32)
self.code_weights = torch.from_numpy(self.code_weights).cuda()
@staticmethod
def smooth_l1_loss(diff, beta):
if beta < 1e-5:
loss = torch.abs(diff)
else:
n = torch.abs(diff)
loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
return loss
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None):
"""
Args:
input: (B, #anchors, #codes) float tensor.
Ecoded predicted locations of objects.
target: (B, #anchors, #codes) float tensor.
Regression targets.
weights: (B, #anchors) float tensor if not None.
Returns:
loss: (B, #anchors) float tensor.
Weighted smooth l1 loss without reduction.
"""
diff = input - target
# code-wise weighting
if self.code_weights is not None:
diff = diff * self.code_weights.view(1, 1, -1)
loss = self.smooth_l1_loss(diff, self.beta)
# anchor-wise weighting
if weights is not None:
assert weights.shape[0] == loss.shape[0] and weights.shape[1] == loss.shape[1]
loss = loss * weights.unsqueeze(-1)
return loss
class WeightedCrossEntropyLoss(nn.Module):
"""
Transform input to fit the fomation of PyTorch offical cross entropy loss
with anchor-wise weighting.
"""
def __init__(self):
super(WeightedCrossEntropyLoss, self).__init__()
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
"""
Args:
input: (B, #anchors, #classes) float tensor.
Predited logits for each class.
target: (B, #anchors, #classes) float tensor.
One-hot classification targets.
weights: (B, #anchors) float tensor.
Anchor-wise weights.
Returns:
loss: (B, #anchors) float tensor.
Weighted cross entropy loss without reduction
"""
input = input.permute(0, 2, 1)
target = target.argmax(dim=-1)
loss = F.cross_entropy(input, target, reduction='none') * weights
return loss
def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor):
"""
Args:
pred_bbox3d: (N, 7) float Tensor.
gt_bbox3d: (N, 7) float Tensor.
Returns:
corner_loss: (N) float Tensor.
"""
assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0]
pred_box_corners = box_utils.boxes_to_corners_3d(pred_bbox3d)
gt_box_corners = box_utils.boxes_to_corners_3d(gt_bbox3d)
gt_bbox3d_flip = gt_bbox3d.clone()
gt_bbox3d_flip[:, 6] += np.pi
gt_box_corners_flip = box_utils.boxes_to_corners_3d(gt_bbox3d_flip)
# (N, 8)
corner_dist = torch.min(torch.norm(pred_box_corners - gt_box_corners, dim=2),
torch.norm(pred_box_corners - gt_box_corners_flip, dim=2))
# (N, 8)
corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0)
return corner_loss.mean(dim=1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment