"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "2c88db907f6738472084185e0ec0c6f6cae20e89"
Commit d0489cf6 authored by Shaoshuai Shi

support WeightedL1Loss in anchor_head_template

parent 04a73eda
@@ -74,9 +74,11 @@ class AnchorHeadTemplate(nn.Module):
             'cls_loss_func',
             loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
         )
+        reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
+            else losses_cfg.REG_LOSS_TYPE
         self.add_module(
             'reg_loss_func',
-            loss_utils.WeightedSmoothL1Loss(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
+            getattr(loss_utils, reg_loss_name)(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
         )
         self.add_module(
             'dir_loss_func',
...
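The hunk above makes the box-regression loss configurable: when a model config provides REG_LOSS_TYPE, that class name is looked up in loss_utils; otherwise WeightedSmoothL1Loss stays the default. Below is a minimal sketch of that lookup, assuming an OpenPCDet-style EasyDict losses_cfg and that loss_utils is pcdet.utils.loss_utils; the key values and code weights are illustrative, not taken from any shipped config, and instantiating the loss needs a CUDA device because the constructors move code_weights to the GPU.

# Illustrative sketch (not part of the commit): resolve the regression loss by name.
from easydict import EasyDict
from pcdet.utils import loss_utils

losses_cfg = EasyDict({
    'REG_LOSS_TYPE': 'WeightedL1Loss',  # drop this key to fall back to WeightedSmoothL1Loss
    'LOSS_WEIGHTS': {'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]},  # one weight per box code
})

reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
    else losses_cfg.REG_LOSS_TYPE
reg_loss_func = getattr(loss_utils, reg_loss_name)(
    code_weights=losses_cfg.LOSS_WEIGHTS['code_weights']
)
print(type(reg_loss_func).__name__)  # WeightedL1Loss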
@@ -135,6 +135,48 @@ class WeightedSmoothL1Loss(nn.Module):
        return loss

class WeightedL1Loss(nn.Module):
    def __init__(self, code_weights: list = None):
        """
        Args:
            code_weights: (#codes) float list if not None.
                Code-wise weights.
        """
        super(WeightedL1Loss, self).__init__()
        if code_weights is not None:
            self.code_weights = np.array(code_weights, dtype=np.float32)
            self.code_weights = torch.from_numpy(self.code_weights).cuda()
        else:
            self.code_weights = None  # forward() reads this attribute, so define it explicitly
    def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None):
        """
        Args:
            input: (B, #anchors, #codes) float tensor.
                Encoded predicted locations of objects.
            target: (B, #anchors, #codes) float tensor.
                Regression targets.
            weights: (B, #anchors) float tensor if not None.

        Returns:
            loss: (B, #anchors, #codes) float tensor.
                Weighted L1 loss without reduction.
        """
        target = torch.where(torch.isnan(target), input, target)  # ignore nan targets
        diff = input - target
        # code-wise weighting
        if self.code_weights is not None:
            diff = diff * self.code_weights.view(1, 1, -1)
        loss = torch.abs(diff)

        # anchor-wise weighting
        if weights is not None:
            assert weights.shape[0] == loss.shape[0] and weights.shape[1] == loss.shape[1]
            loss = loss * weights.unsqueeze(-1)

        return loss

class WeightedCrossEntropyLoss(nn.Module):
    """
    Transform input to fit the format of PyTorch official cross entropy loss
...
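For completeness, a quick usage sketch of the new WeightedL1Loss on dummy tensors, assuming it is importable from pcdet.utils.loss_utils like the other losses and that a CUDA device is available (code_weights is moved to the GPU in __init__); the batch, anchor, and code sizes below are made up:

# Illustrative sketch (not part of the commit): element-wise L1 loss, no reduction.
import torch
from pcdet.utils.loss_utils import WeightedL1Loss

batch_size, num_anchors, num_codes = 2, 4, 7
loss_fn = WeightedL1Loss(code_weights=[1.0] * num_codes)

preds = torch.randn(batch_size, num_anchors, num_codes).cuda()    # encoded box predictions
targets = torch.randn(batch_size, num_anchors, num_codes).cuda()  # regression targets
anchor_weights = torch.ones(batch_size, num_anchors).cuda()       # per-anchor weights

loss = loss_fn(preds, targets, weights=anchor_weights)
print(loss.shape)  # torch.Size([2, 4, 7]); callers reduce over codes and anchors themselves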