Commit 1e35964c authored by Kai Chen

adjust the structure of detectors

parent 678f9334
-from .detectors import Detector
+from .detectors import *
from .builder import *
import math
import torch.nn as nn
import torch.utils.checkpoint as cp
-from torchpack import load_checkpoint
+from mmcv.torchpack import load_checkpoint
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
......
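The conv3x3 body is elided by the diff viewer above; for orientation, a sketch of the standard 3x3-conv helper this signature suggests (an assumption, not the committed body):

import torch.nn as nn

def conv3x3_sketch(in_planes, out_planes, stride=1, dilation=1):
    # padding tied to dilation keeps the spatial size when stride=1
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=False)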
@@ -85,7 +85,7 @@ class BBoxHead(nn.Module):
bbox_pred,
bbox_targets,
bbox_weights,
-                ave_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0))
return losses
def get_det_bboxes(self,
......
import mmcv
-from mmcv import torchpack
+from mmcv import torchpack as tp
from torch import nn
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-                    mask_heads)
+                    mask_heads, detectors)
__all__ = [
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head'
+    'build_bbox_head', 'build_mask_head', 'build_detector'
]
-def _build_module(cfg, parent=None):
-    return cfg if isinstance(cfg, nn.Module) else torchpack.obj_from_dict(
-        cfg, parent)
+def _build_module(cfg, parent=None, default_args=None):
+    return cfg if isinstance(cfg, nn.Module) else tp.obj_from_dict(
+        cfg, parent, default_args)
-def build(cfg, parent=None):
+def build(cfg, parent=None, default_args=None):
if isinstance(cfg, list):
-        modules = [_build_module(cfg_, parent) for cfg_ in cfg]
+        modules = [_build_module(cfg_, parent, default_args) for cfg_ in cfg]
return nn.Sequential(*modules)
else:
-        return _build_module(cfg, parent)
+        return _build_module(cfg, parent, default_args)
def build_backbone(cfg):
@@ -46,3 +45,7 @@ def build_bbox_head(cfg):
def build_mask_head(cfg):
return build(cfg, mask_heads)
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
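For orientation, a minimal usage sketch of these helpers, assuming an mmcv-style config in which 'type' names a class inside the target module (all field values and the cfg variables below are illustrative, not from this commit):

# a dict builds one module from the named class in the parent module;
# a list of dicts is built element-wise and wrapped in nn.Sequential
backbone = build_backbone(dict(type='ResNet', depth=50, num_stages=4))
# build_detector() injects train_cfg/test_cfg as default constructor
# arguments, so detector configs do not need to repeat them
detector = build_detector(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)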
-from .detector import Detector
+from .base import BaseDetector
+from .rpn import RPN
+__all__ = ['BaseDetector', 'RPN']
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors."""
def __init__(self):
super(BaseDetector, self).__init__()
@abstractmethod
def init_weights(self):
pass
@abstractmethod
def extract_feat(self, imgs):
pass
    def extract_feats(self, imgs):
        # this is a generator function (it contains ``yield``), so a plain
        # ``return tensor`` would be lost; wrap a single tensor in a list
        if isinstance(imgs, torch.Tensor):
            imgs = [imgs]
        for img in imgs:
            yield self.extract_feat(img)
@abstractmethod
def forward_train(self, imgs, img_metas, **kwargs):
pass
@abstractmethod
def simple_test(self, img, img_meta, **kwargs):
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
pass
def forward_test(self, imgs, img_metas, **kwargs):
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(imgs), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
imgs_per_gpu = imgs[0].size(0)
assert imgs_per_gpu == 1
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
def forward(self, img, img_meta, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(img, img_meta, **kwargs)
else:
return self.forward_test(img, img_meta, **kwargs)
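To make the contract concrete, a minimal hypothetical subclass that fills in only the abstract methods declared above (illustrative, not part of this commit):

import torch.nn as nn

class DummyDetector(BaseDetector):
    """Illustrative only; returns a trivial loss and raw features."""

    def __init__(self):
        super(DummyDetector, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def init_weights(self):
        nn.init.normal_(self.conv.weight, std=0.01)

    def extract_feat(self, imgs):
        return self.conv(imgs)

    def forward_train(self, imgs, img_metas, **kwargs):
        return dict(loss_dummy=self.extract_feat(imgs).mean())

    def simple_test(self, img, img_meta, **kwargs):
        return self.extract_feat(img)

    def aug_test(self, imgs, img_metas, **kwargs):
        return self.simple_test(imgs[0], img_metas[0])

# forward() dispatches on return_loss:
#   losses = model(img, img_meta, return_loss=True, ...)   -> forward_train
#   result = model([img], [img_meta], return_loss=False)   -> forward_test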
@@ -8,6 +8,7 @@ from mmdet.core import (bbox2roi, bbox_mapping, split_combined_gt_polys,
class Detector(nn.Module):
def __init__(self,
backbone,
neck=None,
......
import mmcv
from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .testing_mixins import RPNTestMixin
from .. import builder
class RPN(BaseDetector, RPNTestMixin):
def __init__(self,
backbone,
neck,
rpn_head,
train_cfg,
test_cfg,
pretrained=None):
super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_rpn_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
if pretrained is not None:
print('load model from: {}'.format(pretrained))
self.backbone.init_weights(pretrained=pretrained)
if self.neck is not None:
self.neck.init_weights()
self.rpn_head.init_weights()
def extract_feat(self, img):
x = self.backbone(img)
if self.neck is not None:
x = self.neck(x)
return x
def forward_train(self, img, img_meta, gt_bboxes=None):
if self.train_cfg.rpn.get('debug', False):
self.rpn_head.debug_imgs = tensor2imgs(img)
x = self.extract_feat(img)
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
losses = self.rpn_head.loss(*rpn_loss_inputs)
return losses
def simple_test(self, img, img_meta, rescale=False):
x = self.extract_feat(img)
proposal_list = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
if rescale:
for proposals, meta in zip(proposal_list, img_meta):
proposals[:, :4] /= meta['scale_factor']
# TODO: remove this restriction
return proposal_list[0].cpu().numpy()
def aug_test(self, imgs, img_metas, rescale=False):
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
if not rescale:
for proposals, img_meta in zip(proposal_list, img_metas[0]):
img_shape = img_meta['img_shape']
scale_factor = img_meta['scale_factor']
flip = img_meta['flip']
proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
scale_factor, flip)
# TODO: remove this restriction
return proposal_list[0].cpu().numpy()
def show_result(self, data, result, img_norm_cfg):
"""Show RPN proposals on the image.
Although we assume batch size is 1, this method supports arbitrary
batch size.
"""
img_tensor = data['img'][0]
img_metas = data['img_meta'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, result, top_k=20)
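A hypothetical end-to-end use of this RPN-only detector; the config fields, cfg variables, and data layout below are assumptions based on the code above:

rpn = RPN(
    backbone=dict(type='ResNet', depth=50),
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
              out_channels=256),
    rpn_head=dict(type='RPNHead', in_channels=256),
    train_cfg=train_cfg,  # assumed config object with .rpn sections
    test_cfg=test_cfg)
rpn.eval()
# one augmentation -> forward_test() dispatches to simple_test(), which
# returns an (n, 5) numpy array of [x1, y1, x2, y2, score] proposals
proposals = rpn([img], [img_meta], return_loss=False)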
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
merge_aug_bboxes, merge_aug_masks, multiclass_nms)
class RPNTestMixin(object):
def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
rpn_outs = self.rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
return proposal_list
def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
imgs_per_gpu = len(img_metas[0])
aug_proposals = [[] for _ in range(imgs_per_gpu)]
for x, img_meta in zip(feats, img_metas):
proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
for i, proposals in enumerate(proposal_list):
aug_proposals[i].append(proposals)
# after merging, proposals will be rescaled to the original image size
merged_proposals = [
merge_aug_proposals(proposals, img_meta, rpn_test_cfg)
for proposals, img_meta in zip(aug_proposals, img_metas)
]
return merged_proposals
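The loop above transposes per-augmentation proposal lists into per-image lists before merging; a toy illustration of that regrouping:

per_aug = [['im0_aug0', 'im1_aug0'],   # proposals from augmentation 0
           ['im0_aug1', 'im1_aug1']]   # proposals from augmentation 1
per_img = [[] for _ in range(len(per_aug[0]))]
for proposal_list in per_aug:
    for i, proposals in enumerate(proposal_list):
        per_img[i].append(proposals)
assert per_img == [['im0_aug0', 'im0_aug1'], ['im1_aug0', 'im1_aug1']]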
class BBoxTestMixin(object):
def simple_test_bboxes(self,
x,
img_meta,
proposals,
rcnn_test_cfg,
rescale=False):
"""Test only det bboxes without augmentation."""
rois = bbox2roi(proposals)
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=rescale,
nms_cfg=rcnn_test_cfg)
return det_bboxes, det_labels
def aug_test_bboxes(self, feats, img_metas, proposals, rcnn_test_cfg):
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
            # map the original proposals into this augmentation's frame
            # without overwriting them, so later iterations still start
            # from the original coordinates
            _proposals = bbox_mapping(proposals[:, :4], img_shape,
                                      scale_factor, flip)
            rois = bbox2roi([_proposals])
# recompute feature maps to save GPU memory
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
bboxes, scores = self.bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
rescale=False,
nms_cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)
# after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
            rcnn_test_cfg.nms_thr, rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels
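Both test paths above funnel boxes through bbox2roi(); judging from its use here, it packs per-image boxes into one (sum_n, 5) tensor whose first column is the image index within the batch. A sketch of that assumed packing:

import torch

def bbox2roi_sketch(bbox_list):
    rois = []
    for img_id, bboxes in enumerate(bbox_list):
        img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
        rois.append(torch.cat([img_inds, bboxes[:, :4]], dim=1))
    return torch.cat(rois, dim=0)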
class MaskTestMixin(object):
def simple_test_mask(self,
x,
img_meta,
det_bboxes,
det_labels,
rescale=False):
# image shape of the first image in the batch (only one)
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
_bboxes = (det_bboxes[:, :4] * scale_factor
if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks(
mask_pred, det_bboxes, det_labels, img_shape,
self.rcnn_test_cfg, rescale)
return segm_result
def aug_test_mask(self,
feats,
img_metas,
det_bboxes,
det_labels,
rescale=False):
if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
else:
aug_masks = []
for x, img_meta in zip(feats, img_metas):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)],
mask_rois)
mask_pred = self.mask_head(mask_feats)
# convert to numpy array to save memory
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
merged_masks = merge_aug_masks(aug_masks, img_metas,
self.rcnn_test_cfg)
segm_result = self.mask_head.get_seg_masks(
merged_masks, _det_bboxes, det_labels,
img_metas[0]['shape_scale'][0], self.rcnn_test_cfg, rescale)
return segm_result
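merge_aug_masks() fuses the per-augmentation sigmoid scores collected above; a minimal averaging sketch of the idea (any augmentation weighting in the real helper is not assumed here):

import numpy as np

def merge_masks_by_average(aug_masks):
    # aug_masks: list of (n, num_classes, h, w) sigmoid score arrays
    return np.mean(np.stack(aug_masks, axis=0), axis=0)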
import torch
import torch.nn as nn
from .base import Detector
from .testing_mixins import RPNTestMixin, BBoxTestMixin
from .. import builder
from mmdet.core import bbox2roi, bbox2result, sample_proposals
class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
def __init__(self,
backbone,
neck=None,
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
        # calls nn.Module.__init__() directly, deliberately skipping
        # Detector's constructor (which requires a backbone argument)
        super(Detector, self).__init__()
self.backbone = builder.build_backbone(backbone)
        self.with_neck = neck is not None
if self.with_neck:
self.neck = builder.build_neck(neck)
        self.with_rpn = rpn_head is not None
if self.with_rpn:
self.rpn_head = builder.build_rpn_head(rpn_head)
        self.with_bbox = bbox_head is not None
if self.with_bbox:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
if pretrained is not None:
print('load model from: {}'.format(pretrained))
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
if self.with_rpn:
self.rpn_head.init_weights()
if self.with_bbox:
self.bbox_roi_extractor.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, img):
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
proposals=None):
losses = dict()
x = self.extract_feat(img)
if self.with_rpn:
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
losses.update(rpn_losses)
            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
else:
proposal_list = proposals
(pos_inds, neg_inds, pos_proposals, neg_proposals,
pos_assigned_gt_inds,
pos_gt_bboxes, pos_gt_labels) = sample_proposals(
proposal_list, gt_bboxes, gt_bboxes_ignore, gt_labels,
self.train_cfg.rcnn)
labels, label_weights, bbox_targets, bbox_weights = \
self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
        # TODO: a more flexible way to configure feature maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
return losses
def simple_test(self, img, img_meta, proposals=None, rescale=False):
"""Test without augmentation."""
x = self.extract_feat(img)
if proposals is None:
            proposals = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
if self.with_bbox:
# BUG proposals shape?
            det_bboxes, det_labels = self.simple_test_bboxes(
                x, img_meta, [proposals], self.test_cfg.rcnn, rescale=rescale)
bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_result
else:
            # img_meta is a list of per-image dicts here
            proposals[:, :4] /= img_meta[0]['scale_factor'].float()
return proposals.cpu().numpy()
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
        proposals = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
        det_bboxes, det_labels = self.aug_test_bboxes(
            self.extract_feats(imgs), img_metas, proposals,
            self.test_cfg.rcnn)
if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= img_metas[0]['shape_scale'][0][-1]
bbox_result = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
return bbox_result
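A hypothetical training step for this detector; the variable names and loss reduction are assumptions (RPN losses arrive as per-level lists, RCNN losses as tensors):

losses = model(img, img_meta, return_loss=True, gt_bboxes=gt_bboxes,
               gt_bboxes_ignore=gt_bboxes_ignore, gt_labels=gt_labels)
loss_terms = []
for value in losses.values():
    # dict values are either tensors or lists of per-level tensors
    loss_terms.extend(value if isinstance(value, list) else [value])
total_loss = sum(term.mean() for term in loss_terms)
total_loss.backward()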
@@ -101,7 +101,7 @@ class FPN(nn.Module):
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
-            laterals[i - 1] += F.upsample(
+            laterals[i - 1] += F.interpolate(
laterals[i], scale_factor=2, mode='nearest')
# build outputs
......
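F.upsample was deprecated in favor of F.interpolate (PyTorch 0.4.1+); for this nearest-neighbour, scale_factor=2 use the two calls are equivalent:

import torch
import torch.nn.functional as F

x = torch.randn(1, 256, 16, 16)
y = F.interpolate(x, scale_factor=2, mode='nearest')  # replacement used above
assert y.shape == (1, 256, 32, 32)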
@@ -9,8 +9,7 @@ from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy)
from mmdet.ops import nms
-from ..utils import multi_apply
-from ..utils import normal_init
+from ..utils import multi_apply, normal_init
class RPNHead(nn.Module):
@@ -66,14 +65,14 @@ class RPNHead(nn.Module):
def forward(self, feats):
return multi_apply(self.forward_single, feats)
-    def get_anchors(self, featmap_sizes, img_shapes):
+    def get_anchors(self, featmap_sizes, img_metas):
"""Get anchors given a list of feature map sizes, and get valid flags
at the same time. (Extra padding regions should be marked as invalid)
"""
# calculate actual image shapes
padded_img_shapes = []
-        for img_shape in img_shapes:
-            h, w = img_shape[:2]
+        for img_meta in img_metas:
+            h, w = img_meta['img_shape'][:2]
padded_h = int(
np.ceil(h / self.coarsest_stride) * self.coarsest_stride)
padded_w = int(
@@ -83,7 +82,7 @@ class RPNHead(nn.Module):
# len = feature levels
anchor_list = []
# len = imgs per gpu
-        valid_flag_list = [[] for _ in range(len(img_shapes))]
+        valid_flag_list = [[] for _ in range(len(img_metas))]
for i in range(len(featmap_sizes)):
anchor_stride = self.anchor_strides[i]
anchors = self.anchor_generators[i].grid_anchors(
@@ -103,26 +102,22 @@ class RPNHead(nn.Module):
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1)
label_weights = label_weights.contiguous().view(-1)
-        bbox_targets = bbox_targets.contiguous().view(-1, 4)
-        bbox_weights = bbox_weights.contiguous().view(-1, 4)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1)
-            loss_cls = weighted_binary_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_binary_cross_entropy
else:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1, 2)
-            loss_cls = weighted_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_cross_entropy
+        loss_cls = criterion(
+            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
+        bbox_targets = bbox_targets.contiguous().view(-1, 4)
+        bbox_weights = bbox_weights.contiguous().view(-1, 4)
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
-1, 4)
loss_reg = weighted_smoothl1(
@@ -130,7 +125,7 @@ class RPNHead(nn.Module):
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
-            ave_factor=num_total_samples)
+            avg_factor=num_total_samples)
return loss_cls, loss_reg
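The avg_factor keyword (renamed from ave_factor throughout this commit) normalizes the summed weighted loss by a fixed sample count rather than by the number of nonzero weights; a sketch of the assumed convention:

import torch

def weighted_smoothl1_sketch(pred, target, weight, avg_factor, beta=1.0):
    diff = (pred - target).abs()
    # smooth L1: quadratic below beta, linear above
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return (weight * loss).sum() / avg_factor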
def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
@@ -158,8 +153,8 @@ class RPNHead(nn.Module):
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
-    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_shapes, cfg):
-        img_per_gpu = len(img_shapes)
+    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
+        num_imgs = len(img_meta)
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
mlvl_anchors = [
self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
@@ -167,7 +162,7 @@ class RPNHead(nn.Module):
for idx in range(len(featmap_sizes))
]
proposal_list = []
-        for img_id in range(img_per_gpu):
+        for img_id in range(num_imgs):
rpn_cls_score_list = [
rpn_cls_scores[idx][img_id].detach()
for idx in range(len(rpn_cls_scores))
@@ -177,10 +172,9 @@ class RPNHead(nn.Module):
for idx in range(len(rpn_bbox_preds))
]
assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
-            img_shape = img_shapes[img_id]
proposals = self._get_proposals_single(
rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
-                img_shape, cfg)
+                img_meta[img_id]['img_shape'], cfg)
proposal_list.append(proposals)
return proposal_list
@@ -195,7 +189,7 @@ class RPNHead(nn.Module):
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(-1)
-            rpn_cls_prob = F.sigmoid(rpn_cls_score)
+            rpn_cls_prob = rpn_cls_score.sigmoid()
scores = rpn_cls_prob
else:
rpn_cls_score = rpn_cls_score.permute(1, 2,
......