smooth_l1_loss.py 525 Bytes
Newer Older
Jiangmiao Pang's avatar
Jiangmiao Pang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
from mmdet.core import weighted_smoothl1

from ..registry import LOSSES


@LOSSES.register_module
class SmoothL1Loss(nn.Module):
    """Loss module that delegates to ``weighted_smoothl1`` and rescales it.

    The class is registered with ``LOSSES`` through the decorator above.

    Args:
        beta: Value forwarded to ``weighted_smoothl1`` as its ``beta``
            keyword argument. Defaults to 1.0.
        loss_weight: Factor the computed loss is multiplied by.
            Defaults to 1.0.
    """

    def __init__(self, beta=1.0, loss_weight=1.0):
        super(SmoothL1Loss, self).__init__()
        self.beta = beta
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight, *args, **kwargs):
        """Compute the scaled loss for ``pred`` against ``target``.

        Args:
            pred: Forwarded unchanged as the first positional argument.
            target: Forwarded unchanged as the second positional argument.
            weight: Forwarded unchanged as the third positional argument.
            *args: Extra positional arguments, passed through after
                ``weight``.
            **kwargs: Extra keyword arguments, passed through as-is.

        Returns:
            ``self.loss_weight`` multiplied by the value returned from
            ``weighted_smoothl1``.
        """
        # NOTE(review): ``beta`` is always supplied by keyword, while any
        # extra positional arguments are placed right after ``weight``.
        # Confirm the callee's fourth positional parameter is not ``beta``,
        # otherwise passing ``*args`` would bind it twice.
        raw_loss = weighted_smoothl1(
            pred, target, weight, *args, beta=self.beta, **kwargs)
        return self.loss_weight * raw_loss