"docs/vscode:/vscode.git/clone" did not exist on "3cb437f5e988ea9d8d8368e2a6453d939d9fbb4b"
Unverified Commit bf9488d7 authored by roger-lcc's avatar roger-lcc Committed by GitHub
Browse files

[Doc]: Add type hints in losses (#2139)

* Add typehints for models/losses

* Add typehints for models/losses

* Add typehints for models/losses

* Add typehints for models/losses
parent d84e081b
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -8,16 +11,16 @@ from mmdet3d.structures import AxisAlignedBboxOverlaps3D ...@@ -8,16 +11,16 @@ from mmdet3d.structures import AxisAlignedBboxOverlaps3D
@weighted_loss @weighted_loss
def axis_aligned_iou_loss(pred, target): def axis_aligned_iou_loss(pred: Tensor, target: Tensor) -> Tensor:
"""Calculate the IoU loss (1-IoU) of two set of axis aligned bounding """Calculate the IoU loss (1-IoU) of two set of axis aligned bounding
boxes. Note that predictions and targets are one-to-one corresponded. boxes. Note that predictions and targets are one-to-one corresponded.
Args: Args:
pred (torch.Tensor): Bbox predictions with shape [..., 3]. pred (Tensor): Bbox predictions with shape [..., 3].
target (torch.Tensor): Bbox targets (gt) with shape [..., 3]. target (Tensor): Bbox targets (gt) with shape [..., 3].
Returns: Returns:
torch.Tensor: IoU loss between predictions and targets. Tensor: IoU loss between predictions and targets.
""" """
axis_aligned_iou = AxisAlignedBboxOverlaps3D()( axis_aligned_iou = AxisAlignedBboxOverlaps3D()(
...@@ -32,38 +35,41 @@ class AxisAlignedIoULoss(nn.Module): ...@@ -32,38 +35,41 @@ class AxisAlignedIoULoss(nn.Module):
Args: Args:
reduction (str): Method to reduce losses. reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0. Defaults to 'mean'.
loss_weight (float): Weight of loss. Defaults to 1.0.
""" """
def __init__(self, reduction='mean', loss_weight=1.0): def __init__(self,
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(AxisAlignedIoULoss, self).__init__() super(AxisAlignedIoULoss, self).__init__()
assert reduction in ['none', 'sum', 'mean'] assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction self.reduction = reduction
self.loss_weight = loss_weight self.loss_weight = loss_weight
def forward(self, def forward(self,
pred, pred: Tensor,
target, target: Tensor,
weight=None, weight: Optional[Tensor] = None,
avg_factor=None, avg_factor: Optional[float] = None,
reduction_override=None, reduction_override: Optional[str] = None,
**kwargs): **kwargs) -> Tensor:
"""Forward function of loss calculation. """Forward function of loss calculation.
Args: Args:
pred (torch.Tensor): Bbox predictions with shape [..., 3]. pred (Tensor): Bbox predictions with shape [..., 3].
target (torch.Tensor): Bbox targets (gt) with shape [..., 3]. target (Tensor): Bbox targets (gt) with shape [..., 3].
weight (torch.Tensor | float, optional): Weight of loss. weight (Tensor, optional): Weight of loss.
Defaults to None. Defaults to None.
avg_factor (int, optional): Average factor that is used to average avg_factor (float, optional): Average factor that is used to
the loss. Defaults to None. average the loss. Defaults to None.
reduction_override (str, optional): Method to reduce losses. reduction_override (str, optional): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'. The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to None. Defaults to None.
Returns: Returns:
torch.Tensor: IoU loss between predictions and targets. Tensor: IoU loss between predictions and targets.
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch import torch
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
def chamfer_distance(src, def chamfer_distance(
dst, src: Tensor,
src_weight=1.0, dst: Tensor,
dst_weight=1.0, src_weight: Union[Tensor, float] = 1.0,
criterion_mode='l2', dst_weight: Union[Tensor, float] = 1.0,
reduction='mean'): criterion_mode: str = 'l2',
reduction: str = 'mean') -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Calculate Chamfer Distance of two sets. """Calculate Chamfer Distance of two sets.
Args: Args:
src (torch.Tensor): Source set with shape [B, N, C] to src (Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
dst (torch.Tensor): Destination set with shape [B, M, C] to dst (Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
src_weight (torch.Tensor or float): Weight of source loss. src_weight (Tensor or float): Weight of source loss. Defaults to 1.0.
dst_weight (torch.Tensor or float): Weight of destination loss. dst_weight (Tensor or float): Weight of destination loss.
Defaults to 1.0.
criterion_mode (str): Criterion mode to calculate distance. criterion_mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2. The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
reduction (str): Method to reduce losses. reduction (str): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'. The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to 'mean'.
Returns: Returns:
tuple: Source and Destination loss with the corresponding indices. tuple: Source and Destination loss with the corresponding indices.
- loss_src (torch.Tensor): The min distance - loss_src (Tensor): The min distance
from source to destination. from source to destination.
- loss_dst (torch.Tensor): The min distance - loss_dst (Tensor): The min distance
from destination to source. from destination to source.
- indices1 (torch.Tensor): Index the min distance point - indices1 (Tensor): Index the min distance point
for each point in source to destination. for each point in source to destination.
- indices2 (torch.Tensor): Index the min distance point - indices2 (Tensor): Index the min distance point
for each point in destination to source. for each point in destination to source.
""" """
...@@ -78,18 +84,19 @@ class ChamferDistance(nn.Module): ...@@ -78,18 +84,19 @@ class ChamferDistance(nn.Module):
Args: Args:
mode (str): Criterion mode to calculate distance. mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2. The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
reduction (str): Method to reduce losses. reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
loss_src_weight (float): Weight of loss_source. Defaults to 'mean'.
loss_dst_weight (float): Weight of loss_target. loss_src_weight (float): Weight of loss_source. Defaults to l.0.
loss_dst_weight (float): Weight of loss_target. Defaults to 1.0.
""" """
def __init__(self, def __init__(self,
mode='l2', mode: str = 'l2',
reduction='mean', reduction: str = 'mean',
loss_src_weight=1.0, loss_src_weight: float = 1.0,
loss_dst_weight=1.0): loss_dst_weight: float = 1.0) -> None:
super(ChamferDistance, self).__init__() super(ChamferDistance, self).__init__()
assert mode in ['smooth_l1', 'l1', 'l2'] assert mode in ['smooth_l1', 'l1', 'l2']
...@@ -99,33 +106,35 @@ class ChamferDistance(nn.Module): ...@@ -99,33 +106,35 @@ class ChamferDistance(nn.Module):
self.loss_src_weight = loss_src_weight self.loss_src_weight = loss_src_weight
self.loss_dst_weight = loss_dst_weight self.loss_dst_weight = loss_dst_weight
def forward(self, def forward(
source, self,
target, source: Tensor,
src_weight=1.0, target: Tensor,
dst_weight=1.0, src_weight: Union[Tensor, float] = 1.0,
reduction_override=None, dst_weight: Union[Tensor, float] = 1.0,
return_indices=False, reduction_override: Optional[str] = None,
**kwargs): return_indices: bool = False,
**kwargs
) -> Union[Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
"""Forward function of loss calculation. """Forward function of loss calculation.
Args: Args:
source (torch.Tensor): Source set with shape [B, N, C] to source (Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
target (torch.Tensor): Destination set with shape [B, M, C] to target (Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance. calculate Chamfer Distance.
src_weight (torch.Tensor | float, optional): src_weight (Tensor | float):
Weight of source loss. Defaults to 1.0. Weight of source loss. Defaults to 1.0.
dst_weight (torch.Tensor | float, optional): dst_weight (Tensor | float):
Weight of destination loss. Defaults to 1.0. Weight of destination loss. Defaults to 1.0.
reduction_override (str, optional): Method to reduce losses. reduction_override (str, optional): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'. The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to None. Defaults to None.
return_indices (bool, optional): Whether to return indices. return_indices (bool): Whether to return indices.
Defaults to False. Defaults to False.
Returns: Returns:
tuple[torch.Tensor]: If ``return_indices=True``, return losses of tuple[Tensor]: If ``return_indices=True``, return losses of
source and target with their corresponding indices in the source and target with their corresponding indices in the
order of ``(loss_source, loss_target, indices1, indices2)``. order of ``(loss_source, loss_target, indices1, indices2)``.
If ``return_indices=False``, return If ``return_indices=False``, return
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -8,21 +11,23 @@ from mmdet3d.registry import MODELS ...@@ -8,21 +11,23 @@ from mmdet3d.registry import MODELS
@weighted_loss @weighted_loss
def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4): def multibin_loss(pred_orientations: Tensor,
gt_orientations: Tensor,
num_dir_bins: int = 4) -> Tensor:
"""Multi-Bin Loss. """Multi-Bin Loss.
Args: Args:
pred_orientations(torch.Tensor): Predicted local vector pred_orientations(Tensor): Predicted local vector
orientation in [axis_cls, head_cls, sin, cos] format. orientation in [axis_cls, head_cls, sin, cos] format.
shape (N, num_dir_bins * 4) shape (N, num_dir_bins * 4)
gt_orientations(torch.Tensor): Corresponding gt bboxes, gt_orientations(Tensor): Corresponding gt bboxes,
shape (N, num_dir_bins * 2). shape (N, num_dir_bins * 2).
num_dir_bins(int, optional): Number of bins to encode num_dir_bins(int): Number of bins to encode
direction angle. direction angle.
Defaults: 4. Defaults to 4.
Return: Returns:
torch.Tensor: Loss tensor. Tensor: Loss tensor.
""" """
cls_losses = 0 cls_losses = 0
reg_losses = 0 reg_losses = 0
...@@ -62,28 +67,37 @@ class MultiBinLoss(nn.Module): ...@@ -62,28 +67,37 @@ class MultiBinLoss(nn.Module):
"""Multi-Bin Loss for orientation. """Multi-Bin Loss for orientation.
Args: Args:
reduction (str, optional): The method to reduce the loss. reduction (str): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'none'. Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
loss_weight (float, optional): The weight of loss. Defaults loss_weight (float): The weight of loss. Defaults
to 1.0. to 1.0.
""" """
def __init__(self, reduction='none', loss_weight=1.0): def __init__(self,
reduction: str = 'none',
loss_weight: float = 1.0) -> None:
super(MultiBinLoss, self).__init__() super(MultiBinLoss, self).__init__()
assert reduction in ['none', 'sum', 'mean'] assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction self.reduction = reduction
self.loss_weight = loss_weight self.loss_weight = loss_weight
def forward(self, pred, target, num_dir_bins, reduction_override=None): def forward(self,
pred: Tensor,
target: Tensor,
num_dir_bins: int,
reduction_override: Optional[str] = None) -> Tensor:
"""Forward function. """Forward function.
Args: Args:
pred (torch.Tensor): The prediction. pred (Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction. target (Tensor): The learning target of the prediction.
num_dir_bins (int): Number of bins to encode direction angle. num_dir_bins (int): Number of bins to encode direction angle.
reduction_override (str, optional): The reduction method used to reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss. override the original reduction method of the loss.
Defaults to None. Defaults to None.
Returns:
Tensor: Loss tensor.
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch import torch
from mmdet.models.losses.utils import weight_reduce_loss from mmdet.models.losses.utils import weight_reduce_loss
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from ..layers import PAConv, PAConvCUDA from ..layers import PAConv, PAConvCUDA
def weight_correlation(conv): def weight_correlation(conv: nn.Module) -> Tensor:
"""Calculate correlations between kernel weights in Conv's weight bank as """Calculate correlations between kernel weights in Conv's weight bank as
regularization loss. The cosine similarity is used as metrics. regularization loss. The cosine similarity is used as metrics.
...@@ -16,7 +19,7 @@ def weight_correlation(conv): ...@@ -16,7 +19,7 @@ def weight_correlation(conv):
Currently we only support `PAConv` and `PAConvCUDA`. Currently we only support `PAConv` and `PAConvCUDA`.
Returns: Returns:
torch.Tensor: Correlations between each kernel weights in weight bank. Tensor: Correlations between each kernel weights in weight bank.
""" """
assert isinstance(conv, (PAConv, PAConvCUDA)), \ assert isinstance(conv, (PAConv, PAConvCUDA)), \
f'unsupported module type {type(conv)}' f'unsupported module type {type(conv)}'
...@@ -44,17 +47,18 @@ def weight_correlation(conv): ...@@ -44,17 +47,18 @@ def weight_correlation(conv):
return corr return corr
def paconv_regularization_loss(modules, reduction): def paconv_regularization_loss(modules: List[nn.Module],
reduction: str) -> Tensor:
"""Computes correlation loss of PAConv weight kernels as regularization. """Computes correlation loss of PAConv weight kernels as regularization.
Args: Args:
modules (List[nn.Module] | :obj:`generator`): modules (List[nn.Module] | :obj:`generator`):
A list or a python generator of torch.nn.Modules. A list or a python generator of torch.nn.Modules.
reduction (str): Method to reduce losses among PAConv modules. reduction (str): Method to reduce losses among PAConv modules.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
Returns: Returns:
torch.Tensor: Correlation loss of kernel weights. Tensor: Correlation loss of kernel weights.
""" """
corr_loss = [] corr_loss = []
for module in modules: for module in modules:
...@@ -77,17 +81,23 @@ class PAConvRegularizationLoss(nn.Module): ...@@ -77,17 +81,23 @@ class PAConvRegularizationLoss(nn.Module):
Args: Args:
reduction (str): Method to reduce losses. The reduction is performed reduction (str): Method to reduce losses. The reduction is performed
among all PAConv modules instead of prediction tensors. among all PAConv modules instead of prediction tensors.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0. Defaults to 'mean'.
loss_weight (float): Weight of loss. Defaults to 1.0.
""" """
def __init__(self, reduction='mean', loss_weight=1.0): def __init__(self,
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(PAConvRegularizationLoss, self).__init__() super(PAConvRegularizationLoss, self).__init__()
assert reduction in ['none', 'sum', 'mean'] assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction self.reduction = reduction
self.loss_weight = loss_weight self.loss_weight = loss_weight
def forward(self, modules, reduction_override=None, **kwargs): def forward(self,
modules: List[nn.Module],
reduction_override: Optional[str] = None,
**kwargs) -> Tensor:
"""Forward function of loss calculation. """Forward function of loss calculation.
Args: Args:
...@@ -98,7 +108,7 @@ class PAConvRegularizationLoss(nn.Module): ...@@ -98,7 +108,7 @@ class PAConvRegularizationLoss(nn.Module):
Defaults to None. Defaults to None.
Returns: Returns:
torch.Tensor: Correlation loss of kernel weights. Tensor: Correlation loss of kernel weights.
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
......
...@@ -11,17 +11,19 @@ from mmdet3d.registry import MODELS ...@@ -11,17 +11,19 @@ from mmdet3d.registry import MODELS
@weighted_loss @weighted_loss
def rotated_iou_3d_loss(pred, target: Tensor) -> Tensor: def rotated_iou_3d_loss(pred: Tensor, target: Tensor) -> Tensor:
"""Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes. """Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes.
Note that predictions and targets are one-to-one corresponded. Note that predictions and targets are one-to-one corresponded.
Args: Args:
pred (torch.Tensor): Bbox predictions with shape [N, 7] pred (Tensor): Bbox predictions with shape [N, 7]
(x, y, z, w, l, h, alpha). (x, y, z, w, l, h, alpha).
target (torch.Tensor): Bbox targets (gt) with shape [N, 7] target (Tensor): Bbox targets (gt) with shape [N, 7]
(x, y, z, w, l, h, alpha). (x, y, z, w, l, h, alpha).
Returns: Returns:
torch.Tensor: IoU loss between predictions and targets. Tensor: IoU loss between predictions and targets.
""" """
iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0), iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0),
target.unsqueeze(0))[0] target.unsqueeze(0))[0]
...@@ -34,13 +36,14 @@ class RotatedIoU3DLoss(nn.Module): ...@@ -34,13 +36,14 @@ class RotatedIoU3DLoss(nn.Module):
Args: Args:
reduction (str): Method to reduce losses. reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean. The valid reduction method are 'none', 'sum' or 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0. Defaults to 'mean'.
loss_weight (float): Weight of loss. Defaults to 1.0.
""" """
def __init__(self, def __init__(self,
reduction: str = 'mean', reduction: str = 'mean',
loss_weight: Optional[float] = 1.0): loss_weight: float = 1.0) -> None:
super().__init__() super().__init__()
self.reduction = reduction self.reduction = reduction
self.loss_weight = loss_weight self.loss_weight = loss_weight
...@@ -49,26 +52,26 @@ class RotatedIoU3DLoss(nn.Module): ...@@ -49,26 +52,26 @@ class RotatedIoU3DLoss(nn.Module):
pred: Tensor, pred: Tensor,
target: Tensor, target: Tensor,
weight: Optional[Tensor] = None, weight: Optional[Tensor] = None,
avg_factor: Optional[int] = None, avg_factor: Optional[float] = None,
reduction_override: Optional[str] = None, reduction_override: Optional[str] = None,
**kwargs) -> Tensor: **kwargs) -> Tensor:
"""Forward function of loss calculation. """Forward function of loss calculation.
Args: Args:
pred (torch.Tensor): Bbox predictions with shape [..., 7] pred (Tensor): Bbox predictions with shape [..., 7]
(x, y, z, w, l, h, alpha). (x, y, z, w, l, h, alpha).
target (torch.Tensor): Bbox targets (gt) with shape [..., 7] target (Tensor): Bbox targets (gt) with shape [..., 7]
(x, y, z, w, l, h, alpha). (x, y, z, w, l, h, alpha).
weight (torch.Tensor | float, optional): Weight of loss. weight (Tensor, optional): Weight of loss.
Defaults to None. Defaults to None.
avg_factor (int, optional): Average factor that is used to average avg_factor (float, optional): Average factor that is used to
the loss. Defaults to None. average the loss. Defaults to None.
reduction_override (str, optional): Method to reduce losses. reduction_override (str, optional): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'. The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to None. Defaults to None.
Returns: Returns:
torch.Tensor: IoU loss between predictions and targets. Tensor: IoU loss between predictions and targets.
""" """
if weight is not None and not torch.any(weight > 0): if weight is not None and not torch.any(weight > 0):
return pred.sum() * weight.sum() # 0 return pred.sum() * weight.sum() # 0
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
@weighted_loss @weighted_loss
def uncertain_smooth_l1_loss(pred, target, sigma, alpha=1.0, beta=1.0): def uncertain_smooth_l1_loss(pred: Tensor,
target: Tensor,
sigma: Tensor,
alpha: float = 1.0,
beta: float = 1.0) -> Tensor:
"""Smooth L1 loss with uncertainty. """Smooth L1 loss with uncertainty.
Args: Args:
pred (torch.Tensor): The prediction. pred (Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction. target (Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty. sigma (Tensor): The sigma for uncertainty.
alpha (float, optional): The coefficient of log(sigma). alpha (float): The coefficient of log(sigma).
Defaults to 1.0. Defaults to 1.0.
beta (float, optional): The threshold in the piecewise function. beta (float): The threshold in the piecewise function.
Defaults to 1.0. Defaults to 1.0.
Returns: Returns:
torch.Tensor: Calculated loss Tensor: Calculated loss
""" """
assert beta > 0 assert beta > 0
assert target.numel() > 0 assert target.numel() > 0
...@@ -36,18 +43,21 @@ def uncertain_smooth_l1_loss(pred, target, sigma, alpha=1.0, beta=1.0): ...@@ -36,18 +43,21 @@ def uncertain_smooth_l1_loss(pred, target, sigma, alpha=1.0, beta=1.0):
@weighted_loss @weighted_loss
def uncertain_l1_loss(pred, target, sigma, alpha=1.0): def uncertain_l1_loss(pred: Tensor,
target: Tensor,
sigma: Tensor,
alpha: float = 1.0) -> Tensor:
"""L1 loss with uncertainty. """L1 loss with uncertainty.
Args: Args:
pred (torch.Tensor): The prediction. pred (Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction. target (Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty. sigma (Tensor): The sigma for uncertainty.
alpha (float, optional): The coefficient of log(sigma). alpha (float): The coefficient of log(sigma).
Defaults to 1.0. Defaults to 1.0.
Returns: Returns:
torch.Tensor: Calculated loss Tensor: Calculated loss
""" """
assert target.numel() > 0 assert target.numel() > 0
assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \ assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \
...@@ -67,16 +77,20 @@ class UncertainSmoothL1Loss(nn.Module): ...@@ -67,16 +77,20 @@ class UncertainSmoothL1Loss(nn.Module):
and Semantics <https://arxiv.org/abs/1705.07115>`_ for more details. and Semantics <https://arxiv.org/abs/1705.07115>`_ for more details.
Args: Args:
alpha (float, optional): The coefficient of log(sigma). alpha (float): The coefficient of log(sigma).
Defaults to 1.0. Defaults to 1.0.
beta (float, optional): The threshold in the piecewise function. beta (float): The threshold in the piecewise function.
Defaults to 1.0. Defaults to 1.0.
reduction (str, optional): The method to reduce the loss. reduction (str): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'mean'. Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
loss_weight (float, optional): The weight of loss. Defaults to 1.0 loss_weight (float): The weight of loss. Defaults to 1.0
""" """
def __init__(self, alpha=1.0, beta=1.0, reduction='mean', loss_weight=1.0): def __init__(self,
alpha: float = 1.0,
beta: float = 1.0,
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(UncertainSmoothL1Loss, self).__init__() super(UncertainSmoothL1Loss, self).__init__()
assert reduction in ['none', 'sum', 'mean'] assert reduction in ['none', 'sum', 'mean']
self.alpha = alpha self.alpha = alpha
...@@ -85,26 +99,29 @@ class UncertainSmoothL1Loss(nn.Module): ...@@ -85,26 +99,29 @@ class UncertainSmoothL1Loss(nn.Module):
self.loss_weight = loss_weight self.loss_weight = loss_weight
def forward(self, def forward(self,
pred, pred: Tensor,
target, target: Tensor,
sigma, sigma: Tensor,
weight=None, weight: Optional[Tensor] = None,
avg_factor=None, avg_factor: Optional[float] = None,
reduction_override=None, reduction_override: Optional[str] = None,
**kwargs): **kwargs) -> Tensor:
"""Forward function. """Forward function.
Args: Args:
pred (torch.Tensor): The prediction. pred (Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction. target (Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty. sigma (Tensor): The sigma for uncertainty.
weight (torch.Tensor, optional): The weight of loss for each weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None. prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average avg_factor (float, optional): Average factor that is used to
the loss. Defaults to None. average the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss. override the original reduction method of the loss.
Defaults to None. Defaults to None.
Returns:
Tensor: Calculated loss
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
...@@ -127,14 +144,17 @@ class UncertainL1Loss(nn.Module): ...@@ -127,14 +144,17 @@ class UncertainL1Loss(nn.Module):
"""L1 loss with uncertainty. """L1 loss with uncertainty.
Args: Args:
alpha (float, optional): The coefficient of log(sigma). alpha (float): The coefficient of log(sigma).
Defaults to 1.0. Defaults to 1.0.
reduction (str, optional): The method to reduce the loss. reduction (str): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'mean'. Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
loss_weight (float, optional): The weight of loss. Defaults to 1.0. loss_weight (float): The weight of loss. Defaults to 1.0.
""" """
def __init__(self, alpha=1.0, reduction='mean', loss_weight=1.0): def __init__(self,
alpha: float = 1.0,
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super(UncertainL1Loss, self).__init__() super(UncertainL1Loss, self).__init__()
assert reduction in ['none', 'sum', 'mean'] assert reduction in ['none', 'sum', 'mean']
self.alpha = alpha self.alpha = alpha
...@@ -142,25 +162,28 @@ class UncertainL1Loss(nn.Module): ...@@ -142,25 +162,28 @@ class UncertainL1Loss(nn.Module):
self.loss_weight = loss_weight self.loss_weight = loss_weight
def forward(self, def forward(self,
pred, pred: Tensor,
target, target: Tensor,
sigma, sigma: Tensor,
weight=None, weight: Optional[Tensor] = None,
avg_factor=None, avg_factor: Optional[float] = None,
reduction_override=None): reduction_override: Optional[str] = None) -> Tensor:
"""Forward function. """Forward function.
Args: Args:
pred (torch.Tensor): The prediction. pred (Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction. target (Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty. sigma (Tensor): The sigma for uncertainty.
weight (torch.Tensor, optional): The weight of loss for each weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None. prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average avg_factor (float, optional): Average factor that is used to
the loss. Defaults to None. average the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss. override the original reduction method of the loss.
Defaults to None. Defaults to None.
Returns:
Tensor: Calculated loss
""" """
assert reduction_override in (None, 'none', 'mean', 'sum') assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = ( reduction = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment