Commit a4372e17 authored by huchen

Merge branch 'hepj_work' into 'main'

Add conformer code

See merge request dcutoolkit/deeplearing/dlexamples_new!7
parents 7f99c1c3 142dcf29
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class MaskScoringRCNN(TwoStageDetector):
"""Mask Scoring RCNN.
https://arxiv.org/abs/1903.00241
"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None):
super(MaskScoringRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class NASFCOS(SingleStageDetector):
"""NAS-FCOS: Fast Neural Architecture Search for Object Detection.
    https://arxiv.org/abs/1906.04423
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class PAA(SingleStageDetector):
"""Implementation of `PAA <https://arxiv.org/pdf/2007.08103.pdf>`_."""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class PointRend(TwoStageDetector):
"""PointRend: Image Segmentation as Rendering
This detector is the implementation of
`PointRend <https://arxiv.org/abs/1912.08193>`_.
"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None):
super(PointRend, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class RepPointsDetector(SingleStageDetector):
"""RepPoints: Point Set Representation for Object Detection.
This detector is the implementation of:
- RepPoints detector (https://arxiv.org/pdf/1904.11490)
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(RepPointsDetector,
self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class RetinaNet(SingleStageDetector):
"""Implementation of `RetinaNet <https://arxiv.org/abs/1708.02002>`_"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
import mmcv
from mmcv.image import tensor2imgs
from mmdet.core import bbox_mapping
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class RPN(BaseDetector):
"""Implementation of Region Proposal Network."""
def __init__(self,
backbone,
neck,
rpn_head,
train_cfg,
test_cfg,
pretrained=None):
super(RPN, self).__init__()
self.backbone = build_backbone(backbone)
self.neck = build_neck(neck) if neck is not None else None
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head.update(train_cfg=rpn_train_cfg)
rpn_head.update(test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(RPN, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
self.neck.init_weights()
self.rpn_head.init_weights()
def extract_feat(self, img):
"""Extract features.
Args:
            img (torch.Tensor): Image tensor with shape (n, c, h, w).
Returns:
list[torch.Tensor]: Multi-level features that may have
different resolutions.
"""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""Dummy forward function."""
x = self.extract_feat(img)
rpn_outs = self.rpn_head(x)
return rpn_outs
def forward_train(self,
img,
img_metas,
gt_bboxes=None,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item is the ground truth boxes of
                each image, in [tl_x, tl_y, br_x, br_y] format.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
if (isinstance(self.train_cfg.rpn, dict)
and self.train_cfg.rpn.get('debug', False)):
self.rpn_head.debug_imgs = tensor2imgs(img)
x = self.extract_feat(img)
losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
gt_bboxes_ignore)
return losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
            img (torch.Tensor): Input image tensor of shape (N, C, H, W).
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[np.ndarray]: proposals
"""
x = self.extract_feat(img)
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
if rescale:
for proposals, meta in zip(proposal_list, img_metas):
proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
return [proposal.cpu().numpy() for proposal in proposal_list]
def aug_test(self, imgs, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
imgs (list[torch.Tensor]): List of multiple images
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[np.ndarray]: proposals
"""
proposal_list = self.rpn_head.aug_test_rpn(
self.extract_feats(imgs), img_metas)
if not rescale:
for proposals, img_meta in zip(proposal_list, img_metas[0]):
img_shape = img_meta['img_shape']
scale_factor = img_meta['scale_factor']
flip = img_meta['flip']
flip_direction = img_meta['flip_direction']
proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
scale_factor, flip,
flip_direction)
return [proposal.cpu().numpy() for proposal in proposal_list]
def show_result(self, data, result, dataset=None, top_k=20):
"""Show RPN proposals on the image.
Although we assume batch size is 1, this method supports arbitrary
batch size.
"""
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, result, top_k=top_k)
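
# --- Illustrative sketch added for clarity (standalone snippet, not part of
# the original rpn.py): how the `rescale` branch of RPN.simple_test maps
# proposals back to original-image coordinates. `scale_factor` is stored per
# image as (w_scale, h_scale, w_scale, h_scale), so dividing the
# (x1, y1, x2, y2) columns undoes the test-time resize. The numbers below are
# made up for the demonstration.
import torch

proposals = torch.tensor([[200., 100., 400., 300., 0.9]])  # box + score on the resized image
scale_factor = (2.0, 2.0, 2.0, 2.0)  # the image was resized by 2x before testing
proposals[:, :4] /= proposals.new_tensor(scale_factor)
print(proposals)  # tensor([[100., 50., 200., 150., 0.9]]) -> original-image coordinates
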
import torch
import torch.nn as nn
from mmdet.core import bbox2result
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
"""Base class for single-stage detectors.
Single-stage detectors directly and densely predict bounding boxes on the
output features of the backbone+neck.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SingleStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(SingleStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item is the ground truth boxes of
                each image, in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_bboxes_ignore)
return losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
            img (torch.Tensor): Input image tensor of shape (N, C, H, W).
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
if torch.onnx.is_in_onnx_export():
return bbox_list
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
]
return bbox_results
def aug_test(self, imgs, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
imgs (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
assert hasattr(self.bbox_head, 'aug_test'), \
f'{self.bbox_head.__class__.__name__}' \
' does not support test-time augmentation'
feats = self.extract_feats(imgs)
return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]
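
# --- Illustrative sketch added for clarity (standalone snippet, not part of
# the original single_stage.py): a simplified, hypothetical stand-in for
# `mmdet.core.bbox2result`, showing the format that simple_test returns:
# one (k, 5) array of [x1, y1, x2, y2, score] rows per class, per image.
# `toy_bbox2result` is an illustrative name, not an mmdet API.
import numpy as np
import torch


def toy_bbox2result(det_bboxes, det_labels, num_classes):
    """Group detected boxes by their predicted class label."""
    if det_bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]
    det_bboxes = det_bboxes.cpu().numpy()
    det_labels = det_labels.cpu().numpy()
    return [det_bboxes[det_labels == i, :] for i in range(num_classes)]


boxes = torch.tensor([[0., 0., 10., 10., 0.9],
                      [5., 5., 20., 20., 0.8]])
labels = torch.tensor([0, 2])
print([r.shape for r in toy_bbox2result(boxes, labels, num_classes=3)])
# [(1, 5), (0, 5), (1, 5)] -> per-class detection arrays for one image
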
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class SparseRCNN(TwoStageDetector):
r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with
Learnable Proposals <https://arxiv.org/abs/2011.12450>`_"""
def __init__(self, *args, **kwargs):
super(SparseRCNN, self).__init__(*args, **kwargs)
        assert self.with_rpn, \
            'Sparse R-CNN does not support external proposals'
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""Forward function of SparseR-CNN in train stage.
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
            gt_masks (list[Tensor], optional): Segmentation masks for each
                box. This is not supported in this architecture.
proposals (List[Tensor], optional): override rpn proposals with
custom proposals. Use when `with_rpn` is False.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
        assert proposals is None, 'Sparse R-CNN does not support' \
            ' external proposals'
        assert gt_masks is None, 'Sparse R-CNN does not support' \
            ' instance segmentation'
x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
self.rpn_head.forward_train(x, img_metas)
roi_losses = self.roi_head.forward_train(
x,
proposal_boxes,
proposal_features,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=gt_bboxes_ignore,
gt_masks=gt_masks,
imgs_whwh=imgs_whwh)
return roi_losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
            img (torch.Tensor): Input image tensor of shape (N, C, H, W).
img_metas (list[dict]): List of image information.
rescale (bool): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
self.rpn_head.simple_test_rpn(x, img_metas)
bbox_results = self.roi_head.simple_test(
x,
proposal_boxes,
proposal_features,
img_metas,
imgs_whwh=imgs_whwh,
rescale=rescale)
return bbox_results
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
# backbone
x = self.extract_feat(img)
# rpn
num_imgs = len(img)
dummy_img_metas = [
dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)
]
proposal_boxes, proposal_features, imgs_whwh = \
self.rpn_head.simple_test_rpn(x, dummy_img_metas)
# roi_head
roi_outs = self.roi_head.forward_dummy(x, proposal_boxes,
proposal_features,
dummy_img_metas)
return roi_outs
from ..builder import DETECTORS
from .faster_rcnn import FasterRCNN
@DETECTORS.register_module()
class TridentFasterRCNN(FasterRCNN):
"""Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None):
super(TridentFasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
assert self.backbone.num_branch == self.roi_head.num_branch
assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
self.num_branch = self.backbone.num_branch
self.test_branch_idx = self.backbone.test_branch_idx
    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation."""
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)
        # compute the branch-expanded metas up front so they are also defined
        # when external proposals are provided
        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
        trident_img_metas = img_metas * num_branch
        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(
                x, trident_img_metas)
        else:
            proposal_list = proposals
        return self.roi_head.simple_test(
            x, proposal_list, trident_img_metas, rescale=rescale)
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
x = self.extract_feats(imgs)
num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
        trident_img_metas = [img_meta * num_branch for img_meta in img_metas]
proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
return self.roi_head.aug_test(
x, proposal_list, img_metas, rescale=rescale)
def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
"""make copies of img and gts to fit multi-branch."""
trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
trident_gt_labels = tuple(gt_labels * self.num_branch)
trident_img_metas = tuple(img_metas * self.num_branch)
return super(TridentFasterRCNN,
self).forward_train(img, trident_img_metas,
trident_gt_bboxes, trident_gt_labels)
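
# --- Illustrative sketch added for clarity (standalone snippet, not part of
# the original trident_faster_rcnn.py): the multi-branch copies made in
# forward_train rely on Python list repetition, so every branch sees the same
# meta / ground-truth objects by reference.
img_metas = [dict(img_shape=(800, 1333, 3))]  # metas for a batch of one image
num_branch = 3
trident_img_metas = img_metas * num_branch
print(len(trident_img_metas))                        # 3 -> one entry per branch
print(trident_img_metas[0] is trident_img_metas[1])  # True -> repeated references, not copies
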
import torch
import torch.nn as nn
# from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
"""Base class for two-stage detectors.
    Two-stage detectors typically consist of a region proposal network and a
task-specific regression head.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TwoStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
self.roi_head = build_head(roi_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
@property
def with_rpn(self):
"""bool: whether the detector has RPN"""
return hasattr(self, 'rpn_head') and self.rpn_head is not None
@property
def with_roi_head(self):
"""bool: whether the detector has a RoI head"""
return hasattr(self, 'roi_head') and self.roi_head is not None
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(TwoStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
if self.with_rpn:
self.rpn_head.init_weights()
if self.with_roi_head:
self.roi_head.init_weights(pretrained)
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).to(img.device)
# roi_head
roi_outs = self.roi_head.forward_dummy(x, proposals)
outs = outs + (roi_outs, )
return outs
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
gt_masks (None | Tensor) : true segmentation masks for each box
used if the architecture supports a segmentation task.
proposals : override rpn proposals with custom proposals. Use when
`with_rpn` is False.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_losses, proposal_list = self.rpn_head.forward_train(
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg)
losses.update(rpn_losses)
else:
proposal_list = proposals
roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)
return losses
async def async_simple_test(self,
img,
img_meta,
proposals=None,
rescale=False):
"""Async test without augmentation."""
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:
proposal_list = await self.rpn_head.async_simple_test_rpn(
x, img_meta)
else:
proposal_list = proposals
return await self.roi_head.async_simple_test(
x, proposal_list, img_meta, rescale=rescale)
def simple_test(self, img, img_metas, proposals=None, rescale=False):
"""Test without augmentation."""
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
else:
proposal_list = proposals
return self.roi_head.simple_test(
x, proposal_list, img_metas, rescale=rescale)
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
x = self.extract_feats(imgs)
proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
return self.roi_head.aug_test(
x, proposal_list, img_metas, rescale=rescale)
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class VFNet(SingleStageDetector):
"""Implementation of `VarifocalNet
(VFNet).<https://arxiv.org/abs/2008.13367>`_"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(VFNet, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
import torch
from mmdet.core import bbox2result
from ..builder import DETECTORS, build_head
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class YOLACT(SingleStageDetector):
"""Implementation of `YOLACT <https://arxiv.org/abs/1904.02689>`_"""
def __init__(self,
backbone,
neck,
bbox_head,
segm_head,
mask_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(YOLACT, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
self.segm_head = build_head(segm_head)
self.mask_head = build_head(mask_head)
self.init_segm_mask_weights()
def init_segm_mask_weights(self):
"""Initialize weights of the YOLACT semg head and YOLACT mask head."""
self.segm_head.init_weights()
self.mask_head.init_weights()
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
raise NotImplementedError
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None):
"""
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
gt_masks (None | Tensor) : true segmentation masks for each box
used if the architecture supports a segmentation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
# convert Bitmap mask or Polygon Mask to Tensor here
gt_masks = [
gt_mask.to_tensor(dtype=torch.uint8, device=img.device)
for gt_mask in gt_masks
]
x = self.extract_feat(img)
cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
bbox_head_loss_inputs = (cls_score, bbox_pred) + (gt_bboxes, gt_labels,
img_metas)
losses, sampling_results = self.bbox_head.loss(
*bbox_head_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
segm_head_outs = self.segm_head(x[0])
loss_segm = self.segm_head.loss(segm_head_outs, gt_masks, gt_labels)
losses.update(loss_segm)
mask_pred = self.mask_head(x[0], coeff_pred, gt_bboxes, img_metas,
sampling_results)
loss_mask = self.mask_head.loss(mask_pred, gt_masks, gt_bboxes,
img_metas, sampling_results)
losses.update(loss_mask)
# check NaN and Inf
for loss_name in losses.keys():
assert torch.isfinite(torch.stack(losses[loss_name]))\
.all().item(), '{} becomes infinite or NaN!'\
.format(loss_name)
return losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation."""
x = self.extract_feat(img)
cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
bbox_inputs = (cls_score, bbox_pred,
coeff_pred) + (img_metas, self.test_cfg, rescale)
det_bboxes, det_labels, det_coeffs = self.bbox_head.get_bboxes(
*bbox_inputs)
bbox_results = [
bbox2result(det_bbox, det_label, self.bbox_head.num_classes)
for det_bbox, det_label in zip(det_bboxes, det_labels)
]
num_imgs = len(img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
segm_results = [[[] for _ in range(self.mask_head.num_classes)]
for _ in range(num_imgs)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
if rescale and not isinstance(scale_factors[0], float):
scale_factors = [
torch.from_numpy(scale_factor).to(det_bboxes[0].device)
for scale_factor in scale_factors
]
_bboxes = [
det_bboxes[i][:, :4] *
scale_factors[i] if rescale else det_bboxes[i][:, :4]
for i in range(len(det_bboxes))
]
mask_preds = self.mask_head(x[0], det_coeffs, _bboxes, img_metas)
# apply mask post-processing to each image individually
segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
segm_results.append(
[[] for _ in range(self.mask_head.num_classes)])
else:
segm_result = self.mask_head.get_seg_masks(
mask_preds[i], det_labels[i], img_metas[i], rescale)
segm_results.append(segm_result)
return list(zip(bbox_results, segm_results))
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations."""
raise NotImplementedError
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
from .accuracy import Accuracy, accuracy
from .ae_loss import AssociativeEmbeddingLoss
from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .gaussian_focal_loss import GaussianFocalLoss
from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
from .ghm_loss import GHMC, GHMR
from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
bounded_iou_loss, iou_loss)
from .mse_loss import MSELoss, mse_loss
from .pisa_loss import carl_loss, isr_p
from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
from .varifocal_loss import VarifocalLoss
__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC',
'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
'VarifocalLoss'
]
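
# --- Illustrative sketch added for clarity (standalone snippet, not part of
# the original losses/__init__.py): every class exported above is registered
# in the LOSSES registry, so configs can build a loss by name. Assuming mmdet
# (and torch) are installed, something like the following is the typical usage.
import torch
from mmdet.models.builder import build_loss

loss_cls = build_loss(
    dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
logits = torch.randn(4, 3)           # 4 samples, 3 classes
labels = torch.tensor([0, 2, 1, 1])
print(loss_cls(logits, labels))      # scalar loss tensor
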
import torch.nn as nn
def accuracy(pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class)
target (torch.Tensor): The target of each prediction, shape (N, )
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Default to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.size(0) == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
assert pred.ndim == 2 and target.ndim == 1
assert pred.size(0) == target.size(0)
assert maxk <= pred.size(1), \
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
pred_value, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t() # transpose to shape (maxk, N)
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
if thresh is not None:
# Only prediction values larger than thresh are counted as correct
correct = correct & (pred_value > thresh).t()
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
class Accuracy(nn.Module):
def __init__(self, topk=(1, ), thresh=None):
"""Module to calculate the accuracy.
Args:
topk (tuple, optional): The criterion used to calculate the
accuracy. Defaults to (1,).
thresh (float, optional): If not None, predictions with scores
under this threshold are considered incorrect. Default to None.
"""
super().__init__()
self.topk = topk
self.thresh = thresh
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
tuple[float]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk, self.thresh)
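
# --- Illustrative sketch added for clarity (standalone snippet, using the
# `accuracy` function defined above): top-k accuracy on a tiny made-up batch.
# Only rows 0 and 1 are correct at top-1, and row 2 is also recovered at
# top-2, so the results are 50% and 75%.
import torch

pred = torch.tensor([[0.10, 0.70, 0.20],
                     [0.80, 0.15, 0.05],
                     [0.30, 0.45, 0.25],
                     [0.20, 0.50, 0.30]])
target = torch.tensor([1, 0, 0, 0])
top1, top2 = accuracy(pred, target, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 75.0
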
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
def ae_loss_per_image(tl_preds, br_preds, match):
"""Associative Embedding Loss in one image.
    Associative Embedding Loss consists of two parts: pull loss and push loss.
    The pull loss makes embedding vectors from the same object closer to each
    other. The push loss separates embedding vectors from different objects
    and makes the gap between them large enough.
    During computation there are usually 3 cases:
        - no object in the image: both pull loss and push loss are 0.
        - one object in the image: the push loss is 0 and the pull loss is
            computed from the two corners of the only object.
        - more than one object in the image: the pull loss is computed from
            the corner pair of each object, and the push loss is computed
            between each object and all other objects. A confusion matrix
            with 0 on the diagonal is used to compute the push loss.
Args:
tl_preds (tensor): Embedding feature map of left-top corner.
        br_preds (tensor): Embedding feature map of bottom-right corner.
match (list): Downsampled coordinates pair of each ground truth box.
"""
tl_list, br_list, me_list = [], [], []
if len(match) == 0: # no object in image
pull_loss = tl_preds.sum() * 0.
push_loss = tl_preds.sum() * 0.
else:
for m in match:
[tl_y, tl_x], [br_y, br_x] = m
tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
br_e = br_preds[:, br_y, br_x].view(-1, 1)
tl_list.append(tl_e)
br_list.append(br_e)
me_list.append((tl_e + br_e) / 2.0)
tl_list = torch.cat(tl_list)
br_list = torch.cat(br_list)
me_list = torch.cat(me_list)
assert tl_list.size() == br_list.size()
# N is object number in image, M is dimension of embedding vector
N, M = tl_list.size()
pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
pull_loss = pull_loss.sum() / N
margin = 1 # exp setting of CornerNet, details in section 3.3 of paper
# confusion matrix of push loss
conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
conf_weight = 1 - torch.eye(N).type_as(me_list)
conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
if N > 1: # more than one object in current image
push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
else:
push_loss = tl_preds.sum() * 0.
return pull_loss, push_loss
@LOSSES.register_module()
class AssociativeEmbeddingLoss(nn.Module):
"""Associative Embedding Loss.
More details can be found in
`Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
`CornerNet <https://arxiv.org/abs/1808.01244>`_ .
Code is modified from `kp_utils.py <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L180>`_ # noqa: E501
Args:
pull_weight (float): Loss weight for corners from same object.
push_weight (float): Loss weight for corners from different object.
"""
def __init__(self, pull_weight=0.25, push_weight=0.25):
super(AssociativeEmbeddingLoss, self).__init__()
self.pull_weight = pull_weight
self.push_weight = push_weight
def forward(self, pred, target, match):
"""Forward function."""
batch = pred.size(0)
pull_all, push_all = 0.0, 0.0
for i in range(batch):
pull, push = ae_loss_per_image(pred[i], target[i], match[i])
pull_all += self.pull_weight * pull
push_all += self.push_weight * push
return pull_all, push_all
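
# --- Illustrative sketch added for clarity (standalone snippet, using
# `ae_loss_per_image` defined above): with a 1-channel embedding map, both
# corners of an object sharing the same value drives the pull loss to 0, and
# the two object means (1.0 vs 3.0) differing by more than the margin of 1
# drives the push loss to 0 as well. All values are made up.
import torch

tl_preds = torch.zeros(1, 4, 4)  # (embedding_dim, H, W) top-left corner embeddings
br_preds = torch.zeros(1, 4, 4)  # bottom-right corner embeddings
tl_preds[0, 0, 0], br_preds[0, 1, 1] = 1.0, 1.0  # object A corners embed to 1.0
tl_preds[0, 2, 2], br_preds[0, 3, 3] = 3.0, 3.0  # object B corners embed to 3.0
match = [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]     # [tl_y, tl_x], [br_y, br_x] per object
pull, push = ae_loss_per_image(tl_preds, br_preds, match)
print(pull.item(), push.item())  # 0.0 0.0
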
import numpy as np
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def balanced_l1_loss(pred,
target,
beta=1.0,
alpha=0.5,
gamma=1.5,
reduction='mean'):
"""Calculate balanced L1 loss.
Please see the `Libra R-CNN <https://arxiv.org/pdf/1904.02701.pdf>`_
Args:
pred (torch.Tensor): The prediction with shape (N, 4).
target (torch.Tensor): The learning target of the prediction with
shape (N, 4).
beta (float): The loss is a piecewise function of prediction and target
and ``beta`` serves as a threshold for the difference between the
prediction and target. Defaults to 1.0.
alpha (float): The denominator ``alpha`` in the balanced L1 loss.
Defaults to 0.5.
gamma (float): The ``gamma`` in the balanced L1 loss.
Defaults to 1.5.
reduction (str, optional): The method that reduces the loss to a
scalar. Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert beta > 0
assert pred.size() == target.size() and target.numel() > 0
diff = torch.abs(pred - target)
b = np.e**(gamma / alpha) - 1
loss = torch.where(
diff < beta, alpha / b *
(b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
gamma * diff + gamma / b - alpha * beta)
return loss
@LOSSES.register_module()
class BalancedL1Loss(nn.Module):
"""Balanced L1 Loss.
arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
Args:
alpha (float): The denominator ``alpha`` in the balanced L1 loss.
Defaults to 0.5.
gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
beta (float, optional): The loss is a piecewise function of prediction
and target. ``beta`` serves as a threshold for the difference
between the prediction and target. Defaults to 1.0.
reduction (str, optional): The method that reduces the loss to a
scalar. Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of the loss. Defaults to 1.0
"""
def __init__(self,
alpha=0.5,
gamma=1.5,
beta=1.0,
reduction='mean',
loss_weight=1.0):
super(BalancedL1Loss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.beta = beta
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function of loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 4).
target (torch.Tensor): The learning target of the prediction with
shape (N, 4).
weight (torch.Tensor, optional): Sample-wise loss weight with
shape (N, ).
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * balanced_l1_loss(
pred,
target,
weight,
alpha=self.alpha,
gamma=self.gamma,
beta=self.beta,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_bbox
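
# --- Illustrative sketch added for clarity (standalone snippet, using
# `balanced_l1_loss` defined above): the loss is piecewise in the absolute
# residual d = |pred - target|. With the defaults alpha=0.5, gamma=1.5,
# beta=1.0 and b = e^(gamma/alpha) - 1 ~= 19.09, d = 0 gives 0 and d = 2
# falls in the linear branch: gamma*d + gamma/b - alpha*beta ~= 2.579.
import torch

pred = torch.tensor([0.0, 2.0])
target = torch.zeros(2)
print(balanced_l1_loss(pred, target, reduction='none'))
# tensor([0.0000, 2.5786])
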
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None):
"""Calculate the CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
Returns:
torch.Tensor: The calculated loss
"""
# element-wise losses
loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(
(labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
if label_weights is None:
bin_label_weights = None
else:
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
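
# --- Illustrative sketch added for clarity (standalone snippet, using
# `_expand_onehot_labels` defined above): in mmdet the background label equals
# the number of foreground classes, so it falls outside [0, label_channels)
# and its row stays all zeros instead of becoming a positive target.
import torch

labels = torch.tensor([0, 2, 3])  # the last entry is "background" for 3 classes
bin_labels, _ = _expand_onehot_labels(labels, None, label_channels=3)
print(bin_labels)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 0, 0]])
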
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
Returns:
torch.Tensor: The calculated loss
"""
if pred.dim() != label.dim():
label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask's
            corresponding object. It is used to select the mask of the class
            the object belongs to when the mask prediction is not
            class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
Returns:
torch.Tensor: The calculated loss
"""
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
"""CrossEntropyLoss.
Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function.
Args:
cls_score (torch.Tensor): The prediction.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(
self.class_weight, device=cls_score.device)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_cls
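
# --- Illustrative sketch added for clarity (standalone snippet, using the
# `CrossEntropyLoss` module defined above): the module simply dispatches to
# one of the three criteria above. Here, plain softmax cross-entropy on a
# tiny made-up batch, with a per-sample weight that zeroes out the second
# sample.
import torch

criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
cls_score = torch.tensor([[2.0, 0.5, 0.1],
                          [0.2, 1.5, 0.3]])
label = torch.tensor([0, 1])
weight = torch.tensor([1.0, 0.0])           # drop the second sample from the loss
print(criterion(cls_score, label))          # mean loss over both samples
print(criterion(cls_score, label, weight))  # the second sample contributes 0
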
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from ..builder import LOSSES
from .utils import weight_reduce_loss
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A warpper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred.contiguous(), target, gamma, alpha, None,
'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module()
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
"""
super(FocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
if torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
weight,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
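
# --- Illustrative sketch added for clarity (standalone snippet, using the
# `FocalLoss` module defined above; it assumes mmcv with compiled ops is
# installed, since this file imports them). On CPU tensors the forward falls
# back to py_sigmoid_focal_loss, one-hot encodes the integer labels and
# applies FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) to every class
# logit before averaging.
import torch

focal = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
pred = torch.tensor([[2.0, -1.0, -2.0],
                     [-1.5, 0.5, -0.5]])  # logits for 3 foreground classes
target = torch.tensor([0, 1])             # integer class labels
print(focal(pred, target))                # scalar loss averaged over all entries
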