Commit 1e35964c authored by Kai Chen

adjust the structure of detectors

parent 678f9334
-from .detectors import Detector
+from .detectors import *
from .builder import *
import math
import torch.nn as nn
import torch.utils.checkpoint as cp
-from torchpack import load_checkpoint
+from mmcv.torchpack import load_checkpoint

def conv3x3(in_planes, out_planes, stride=1, dilation=1):
...
@@ -60,7 +60,7 @@ class BBoxHead(nn.Module):
        return cls_score, bbox_pred

    def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes,
                        pos_gt_labels, rcnn_train_cfg):
        reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes
        cls_reg_targets = bbox_target(
            pos_proposals,
@@ -85,7 +85,7 @@ class BBoxHead(nn.Module):
            bbox_pred,
            bbox_targets,
            bbox_weights,
-            ave_factor=bbox_targets.size(0))
+            avg_factor=bbox_targets.size(0))
        return losses

    def get_det_bboxes(self,
...
-import mmcv
-from mmcv import torchpack
+from mmcv import torchpack as tp
from torch import nn

from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads)
+               mask_heads, detectors)

__all__ = [
    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head'
+    'build_bbox_head', 'build_mask_head', 'build_detector'
]

-def _build_module(cfg, parrent=None):
-    return cfg if isinstance(cfg, nn.Module) else torchpack.obj_from_dict(
-        cfg, parrent)
+def _build_module(cfg, parrent=None, default_args=None):
+    return cfg if isinstance(cfg, nn.Module) else tp.obj_from_dict(
+        cfg, parrent, default_args)

-def build(cfg, parrent=None):
+def build(cfg, parrent=None, default_args=None):
    if isinstance(cfg, list):
-        modules = [_build_module(cfg_, parrent) for cfg_ in cfg]
+        modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
-        return _build_module(cfg, parrent)
+        return _build_module(cfg, parrent, default_args)

def build_backbone(cfg):
@@ -46,3 +45,7 @@ def build_bbox_head(cfg):

def build_mask_head(cfg):
    return build(cfg, mask_heads)

+def build_detector(cfg, train_cfg=None, test_cfg=None):
+    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
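The new build_detector entry point threads train_cfg and test_cfg through build() as default_args, so tp.obj_from_dict merges them into the constructor kwargs of whichever detector class the config names. A minimal usage sketch, assuming illustrative config values (the 'type' strings must name classes registered in the corresponding modules; the backbone/neck/head parameters below are placeholders, not taken from a real config file):

from mmdet.models import build_detector

model_cfg = dict(
    type='RPN',                              # class looked up in mmdet.models.detectors
    backbone=dict(type='ResNet', depth=50),  # placeholder backbone args
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256),
    rpn_head=dict(type='RPNHead', in_channels=256),  # placeholder head args
    pretrained=None)

# train_cfg/test_cfg arrive via default_args, so they are not part of model_cfg.
model = build_detector(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)  # cfgs defined elsewhere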
-from .detector import Detector
+from .base import BaseDetector
+from .rpn import RPN
+
+__all__ = ['BaseDetector', 'RPN']
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn


class BaseDetector(nn.Module):
    """Base class for detectors"""

    __metaclass__ = ABCMeta

    def __init__(self):
        super(BaseDetector, self).__init__()

    @abstractmethod
    def init_weights(self):
        pass

    @abstractmethod
    def extract_feat(self, imgs):
        pass

    def extract_feats(self, imgs):
        if isinstance(imgs, torch.Tensor):
            return self.extract_feat(imgs)
        elif isinstance(imgs, list):
            for img in imgs:
                yield self.extract_feat(img)

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        pass

    def forward_test(self, imgs, img_metas, **kwargs):
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)
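To make the abstract contract concrete, here is a minimal hypothetical subclass (not part of this commit): a detector only has to supply weight init, feature extraction, and the three train/test entry points, and the inherited forward() dispatches on return_loss:

import torch
import torch.nn as nn


class ToyDetector(BaseDetector):
    """Illustrative subclass only; not part of the commit."""

    def __init__(self):
        super(ToyDetector, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def init_weights(self):
        nn.init.normal_(self.conv.weight, std=0.01)

    def extract_feat(self, imgs):
        return self.conv(imgs)

    def forward_train(self, imgs, img_metas, **kwargs):
        # real detectors return a dict of named losses
        return dict(loss_toy=self.extract_feat(imgs).mean())

    def simple_test(self, img, img_meta, **kwargs):
        return self.extract_feat(img)

    def aug_test(self, imgs, img_metas, **kwargs):
        return [self.extract_feat(img) for img in imgs]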
@@ -8,6 +8,7 @@ from mmdet.core import (bbox2roi, bbox_mapping, split_combined_gt_polys,

class Detector(nn.Module):

    def __init__(self,
                 backbone,
                 neck=None,
...
import mmcv

from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .testing_mixins import RPNTestMixin
from .. import builder


class RPN(BaseDetector, RPNTestMixin):

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        super(RPN, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck) if neck is not None else None
        self.rpn_head = builder.build_rpn_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        if pretrained is not None:
            print('load model from: {}'.format(pretrained))
        self.backbone.init_weights(pretrained=pretrained)
        if self.neck is not None:
            self.neck.init_weights()
        self.rpn_head.init_weights()

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.neck is not None:
            x = self.neck(x)
        return x

    def forward_train(self, img, img_meta, gt_bboxes=None):
        if self.train_cfg.rpn.get('debug', False):
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)

        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
        losses = self.rpn_head.loss(*rpn_loss_inputs)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        proposal_list = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
        if rescale:
            for proposals, meta in zip(proposal_list, img_meta):
                proposals[:, :4] /= meta['scale_factor']
        # TODO: remove this restriction
        return proposal_list[0].cpu().numpy()

    def aug_test(self, imgs, img_metas, rescale=False):
        proposal_list = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
        if not rescale:
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip)
        # TODO: remove this restriction
        return proposal_list[0].cpu().numpy()

    def show_result(self, data, result, img_norm_cfg):
        """Show RPN proposals on the image.

        Although we assume batch size is 1, this method supports arbitrary
        batch size.
        """
        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_norm_cfg)
        assert len(imgs) == len(img_metas)
        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]
            mmcv.imshow_bboxes(img_show, result, top_k=20)
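A rough test-time usage sketch for this class ('model' is assumed to be a built RPN instance, e.g. from the build_detector sketch earlier; shapes and meta values are invented). forward(return_loss=False) routes through BaseDetector.forward_test, which requires lists with one entry per augmentation and batch size 1:

import torch

img = torch.randn(1, 3, 800, 1216)
img_meta = [dict(img_shape=(800, 1216, 3), scale_factor=1.0, flip=False)]

with torch.no_grad():
    proposals = model([img], [img_meta], return_loss=False)
# 'proposals' is a numpy array of boxes, as returned by simple_test above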
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
                        merge_aug_bboxes, merge_aug_masks, multiclass_nms)


class RPNTestMixin(object):

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
        return proposal_list

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
        imgs_per_gpu = len(img_metas[0])
        aug_proposals = [[] for _ in range(imgs_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)
        # after merging, proposals will be rescaled to the original image size
        merged_proposals = [
            merge_aug_proposals(proposals, img_meta, rpn_test_cfg)
            for proposals, img_meta in zip(aug_proposals, img_metas)
        ]
        return merged_proposals


class BBoxTestMixin(object):

    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation."""
        rois = bbox2roi(proposals)
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            nms_cfg=rcnn_test_cfg)
        return det_bboxes, det_labels

    def aug_test_bboxes(self, feats, img_metas, proposals, rcnn_test_cfg):
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            proposals = bbox_mapping(proposals[:, :4], img_shape, scale_factor,
                                     flip)
            rois = bbox2roi([proposals])
            # recompute feature maps to save GPU memory
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            cls_score, bbox_pred = self.bbox_head(roi_feats)
            bboxes, scores = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                rescale=False,
                nms_cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, self.rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, self.rcnn_test_cfg.score_thr,
            self.rcnn_test_cfg.nms_thr, self.rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):

    def simple_test_mask(self,
                         x,
                         img_meta,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        # image shape of the first image in the batch (only one)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes = (det_bboxes[:, :4] * scale_factor
                       if rescale else det_bboxes)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            mask_pred = self.mask_head(mask_feats)
            segm_result = self.mask_head.get_seg_masks(
                mask_pred, det_bboxes, det_labels, img_shape,
                self.rcnn_test_cfg, rescale)
        return segm_result

    def aug_test_mask(self,
                      feats,
                      img_metas,
                      det_bboxes,
                      det_labels,
                      rescale=False):
        if rescale:
            _det_bboxes = det_bboxes
        else:
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self.mask_roi_extractor(
                    x[:len(self.mask_roi_extractor.featmap_strides)],
                    mask_rois)
                mask_pred = self.mask_head(mask_feats)
                # convert to numpy array to save memory
                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas,
                                           self.rcnn_test_cfg)
            segm_result = self.mask_head.get_seg_masks(
                merged_masks, _det_bboxes, det_labels,
                img_metas[0]['shape_scale'][0], self.rcnn_test_cfg, rescale)
        return segm_result
import torch
import torch.nn as nn

from .base import BaseDetector
from .testing_mixins import RPNTestMixin, BBoxTestMixin
from .. import builder
from mmdet.core import bbox2roi, bbox2result, sample_proposals


class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin):

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)

        self.with_neck = True if neck is not None else False
        if self.with_neck:
            self.neck = builder.build_neck(neck)

        self.with_rpn = True if rpn_head is not None else False
        if self.with_rpn:
            self.rpn_head = builder.build_rpn_head(rpn_head)

        self.with_bbox = True if bbox_head is not None else False
        if self.with_bbox:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_bbox_head(bbox_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        if pretrained is not None:
            print('load model from: {}'.format(pretrained))
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        if self.with_bbox:
            self.bbox_roi_extractor.init_weights()
            self.bbox_head.init_weights()

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_bboxes_ignore,
                      gt_labels,
                      proposals=None):
        losses = dict()

        x = self.extract_feat(img)

        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
            losses.update(rpn_losses)

            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
        else:
            proposal_list = proposals

        (pos_inds, neg_inds, pos_proposals, neg_proposals,
         pos_assigned_gt_inds,
         pos_gt_bboxes, pos_gt_labels) = sample_proposals(
             proposal_list, gt_bboxes, gt_bboxes_ignore, gt_labels,
             self.train_cfg.rcnn)

        labels, label_weights, bbox_targets, bbox_weights = \
            self.bbox_head.get_bbox_target(
                pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
                self.train_cfg.rcnn)

        rois = bbox2roi([
            torch.cat([pos, neg], dim=0)
            for pos, neg in zip(pos_proposals, neg_proposals)
        ])
        # TODO: a more flexible way to configurate feat maps
        roi_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)

        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
                                        label_weights, bbox_targets,
                                        bbox_weights)
        losses.update(loss_bbox)

        return losses

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Test without augmentation."""
        x = self.extract_feat(img)
        if proposals is None:
            proposals = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
        if self.with_bbox:
            # BUG proposals shape?
            det_bboxes, det_labels = self.simple_test_bboxes(
                x, img_meta, [proposals], self.test_cfg.rcnn, rescale=rescale)
            bbox_result = bbox2result(det_bboxes, det_labels,
                                      self.bbox_head.num_classes)
            return bbox_result
        else:
            proposals[:, :4] /= img_meta['scale_factor'].float()
            return proposals.cpu().numpy()

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        proposals = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
        det_bboxes, det_labels = self.aug_test_bboxes(
            self.extract_feats(imgs), img_metas, proposals,
            self.test_cfg.rcnn)
        if rescale:
            _det_bboxes = det_bboxes
        else:
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= img_metas[0]['shape_scale'][0][-1]
        bbox_result = bbox2result(_det_bboxes, det_labels,
                                  self.bbox_head.num_classes)
        return bbox_result
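forward_train above returns a single flat dict that merges the RPN losses (per-level lists produced by multi_apply in RPNHead.loss) with the bbox head losses. A sketch of reducing such a dict to one scalar for backward(); the bbox head key names are assumptions and the values are fabricated:

import torch

losses = dict(
    loss_rpn_cls=[torch.tensor(0.3), torch.tensor(0.2)],  # one entry per level
    loss_rpn_reg=[torch.tensor(0.1), torch.tensor(0.1)],
    loss_cls=torch.tensor(0.9),   # assumed bbox head keys
    loss_reg=torch.tensor(0.4))

# sum tensors directly and per-level lists element-wise
total = sum(v if torch.is_tensor(v) else sum(v) for v in losses.values())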
@@ -101,7 +101,7 @@ class FPN(nn.Module):
        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
-            laterals[i - 1] += F.upsample(
+            laterals[i - 1] += F.interpolate(
                laterals[i], scale_factor=2, mode='nearest')

        # build outputs
...
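The change above tracks PyTorch's API: F.upsample was deprecated in favor of F.interpolate, and the two calls are equivalent for this nearest-neighbor, scale_factor=2 case. A tiny standalone check with invented shapes:

import torch
import torch.nn.functional as F

lateral = torch.randn(1, 256, 25, 38)   # made-up FPN lateral feature map
up = F.interpolate(lateral, scale_factor=2, mode='nearest')
assert up.shape == (1, 256, 50, 76)     # spatial dims doubled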
@@ -9,8 +9,7 @@ from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
                        weighted_cross_entropy, weighted_smoothl1,
                        weighted_binary_cross_entropy)
from mmdet.ops import nms
-from ..utils import multi_apply
-from ..utils import normal_init
+from ..utils import multi_apply, normal_init

class RPNHead(nn.Module):
@@ -66,14 +65,14 @@ class RPNHead(nn.Module):
    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

-    def get_anchors(self, featmap_sizes, img_shapes):
+    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors given a list of feature map sizes, and get valid flags
        at the same time. (Extra padding regions should be marked as invalid)
        """
        # calculate actual image shapes
        padded_img_shapes = []
-        for img_shape in img_shapes:
-            h, w = img_shape[:2]
+        for img_meta in img_metas:
+            h, w = img_meta['img_shape'][:2]
            padded_h = int(
                np.ceil(h / self.coarsest_stride) * self.coarsest_stride)
            padded_w = int(
@@ -83,7 +82,7 @@ class RPNHead(nn.Module):
        # len = feature levels
        anchor_list = []
        # len = imgs per gpu
-        valid_flag_list = [[] for _ in range(len(img_shapes))]
+        valid_flag_list = [[] for _ in range(len(img_metas))]
        for i in range(len(featmap_sizes)):
            anchor_stride = self.anchor_strides[i]
            anchors = self.anchor_generators[i].grid_anchors(
@@ -103,26 +102,22 @@ class RPNHead(nn.Module):
    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
+        # classification loss
        labels = labels.contiguous().view(-1)
        label_weights = label_weights.contiguous().view(-1)
-        bbox_targets = bbox_targets.contiguous().view(-1, 4)
-        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1)
-            loss_cls = weighted_binary_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_binary_cross_entropy
        else:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1, 2)
-            loss_cls = weighted_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_cross_entropy
+        loss_cls = criterion(
+            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
+        # regression loss
+        bbox_targets = bbox_targets.contiguous().view(-1, 4)
+        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
            -1, 4)
        loss_reg = weighted_smoothl1(
@@ -130,7 +125,7 @@ class RPNHead(nn.Module):
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
-            ave_factor=num_total_samples)
+            avg_factor=num_total_samples)
        return loss_cls, loss_reg

    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
@@ -158,8 +153,8 @@ class RPNHead(nn.Module):
            cfg=cfg)
        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

-    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_shapes, cfg):
-        img_per_gpu = len(img_shapes)
+    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
+        num_imgs = len(img_meta)
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        mlvl_anchors = [
            self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
@@ -167,7 +162,7 @@ class RPNHead(nn.Module):
            for idx in range(len(featmap_sizes))
        ]
        proposal_list = []
-        for img_id in range(img_per_gpu):
+        for img_id in range(num_imgs):
            rpn_cls_score_list = [
                rpn_cls_scores[idx][img_id].detach()
                for idx in range(len(rpn_cls_scores))
@@ -177,10 +172,9 @@ class RPNHead(nn.Module):
                for idx in range(len(rpn_bbox_preds))
            ]
            assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
-            img_shape = img_shapes[img_id]
            proposals = self._get_proposals_single(
                rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
-                img_shape, cfg)
+                img_meta[img_id]['img_shape'], cfg)
            proposal_list.append(proposals)
        return proposal_list
@@ -195,7 +189,7 @@ class RPNHead(nn.Module):
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                  0).contiguous().view(-1)
-            rpn_cls_prob = F.sigmoid(rpn_cls_score)
+            rpn_cls_prob = rpn_cls_score.sigmoid()
            scores = rpn_cls_prob
        else:
            rpn_cls_score = rpn_cls_score.permute(1, 2,
...