Unverified Commit 699cb914 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Format the code and add yapf to travis (#1079)

* format the codebase with yapf

* add yapf to travis
parent 86cc430a
...@@ -2,7 +2,7 @@ dist: xenial ...@@ -2,7 +2,7 @@ dist: xenial
language: python language: python
install: install:
- pip install flake8 - pip install flake8 yapf
python: python:
- "3.5" - "3.5"
...@@ -11,3 +11,4 @@ python: ...@@ -11,3 +11,4 @@ python:
script: script:
- flake8 - flake8
- yapf -r -d --style .style.yapf mmdet/ tools/
\ No newline at end of file
...@@ -33,12 +33,14 @@ class AnchorGenerator(object): ...@@ -33,12 +33,14 @@ class AnchorGenerator(object):
ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1) ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1) hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)
# yapf: disable
base_anchors = torch.stack( base_anchors = torch.stack(
[ [
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
], ],
dim=-1).round() dim=-1).round()
# yapf: enable
return base_anchors return base_anchors
......
...@@ -62,12 +62,13 @@ def ga_loc_target(gt_bboxes_list, ...@@ -62,12 +62,13 @@ def ga_loc_target(gt_bboxes_list,
all_ignore_map = [] all_ignore_map = []
for lvl_id in range(num_lvls): for lvl_id in range(num_lvls):
h, w = featmap_sizes[lvl_id] h, w = featmap_sizes[lvl_id]
loc_targets = torch.zeros(img_per_gpu, loc_targets = torch.zeros(
1, img_per_gpu,
h, 1,
w, h,
device=gt_bboxes_list[0].device, w,
dtype=torch.float32) device=gt_bboxes_list[0].device,
dtype=torch.float32)
loc_weights = torch.full_like(loc_targets, -1) loc_weights = torch.full_like(loc_targets, -1)
ignore_map = torch.zeros_like(loc_targets) ignore_map = torch.zeros_like(loc_targets)
all_loc_targets.append(loc_targets) all_loc_targets.append(loc_targets)
...@@ -175,17 +176,18 @@ def ga_shape_target(approx_list, ...@@ -175,17 +176,18 @@ def ga_shape_target(approx_list,
if gt_bboxes_ignore_list is None: if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)] gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
(all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list, (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list,
neg_inds_list) = multi_apply(ga_shape_target_single, neg_inds_list) = multi_apply(
approx_flat_list, ga_shape_target_single,
inside_flag_flat_list, approx_flat_list,
square_flat_list, inside_flag_flat_list,
gt_bboxes_list, square_flat_list,
gt_bboxes_ignore_list, gt_bboxes_list,
img_metas, gt_bboxes_ignore_list,
approxs_per_octave=approxs_per_octave, img_metas,
cfg=cfg, approxs_per_octave=approxs_per_octave,
sampling=sampling, cfg=cfg,
unmap_outputs=unmap_outputs) sampling=sampling,
unmap_outputs=unmap_outputs)
# no valid anchors # no valid anchors
if any([bbox_anchors is None for bbox_anchors in all_bbox_anchors]): if any([bbox_anchors is None for bbox_anchors in all_bbox_anchors]):
return None return None
......
...@@ -101,14 +101,12 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner): ...@@ -101,14 +101,12 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and ( if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
gt_bboxes_ignore.numel() > 0): gt_bboxes_ignore.numel() > 0):
if self.ignore_wrt_candidates: if self.ignore_wrt_candidates:
ignore_overlaps = bbox_overlaps(bboxes, ignore_overlaps = bbox_overlaps(
gt_bboxes_ignore, bboxes, gt_bboxes_ignore, mode='iof')
mode='iof')
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
else: else:
ignore_overlaps = bbox_overlaps(gt_bboxes_ignore, ignore_overlaps = bbox_overlaps(
bboxes, gt_bboxes_ignore, bboxes, mode='iof')
mode='iof')
ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
......
...@@ -107,8 +107,9 @@ class MaxIoUAssigner(BaseAssigner): ...@@ -107,8 +107,9 @@ class MaxIoUAssigner(BaseAssigner):
num_gts, num_bboxes = overlaps.size(0), overlaps.size(1) num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
# 1. assign -1 by default # 1. assign -1 by default
assigned_gt_inds = overlaps.new_full( assigned_gt_inds = overlaps.new_full((num_bboxes, ),
(num_bboxes, ), -1, dtype=torch.long) -1,
dtype=torch.long)
# for each anchor, which gt best overlaps with it # for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts # for each anchor, the max iou of all gts
......
...@@ -62,10 +62,10 @@ def bbox_target_single(pos_bboxes, ...@@ -62,10 +62,10 @@ def bbox_target_single(pos_bboxes,
def expand_target(bbox_targets, bbox_weights, labels, num_classes): def expand_target(bbox_targets, bbox_weights, labels, num_classes):
bbox_targets_expand = bbox_targets.new_zeros((bbox_targets.size(0), bbox_targets_expand = bbox_targets.new_zeros(
4 * num_classes)) (bbox_targets.size(0), 4 * num_classes))
bbox_weights_expand = bbox_weights.new_zeros((bbox_weights.size(0), bbox_weights_expand = bbox_weights.new_zeros(
4 * num_classes)) (bbox_weights.size(0), 4 * num_classes))
for i in torch.nonzero(labels > 0).squeeze(-1): for i in torch.nonzero(labels > 0).squeeze(-1):
start, end = labels[i] * 4, (labels[i] + 1) * 4 start, end = labels[i] * 4, (labels[i] + 1) * 4
bbox_targets_expand[i, start:end] = bbox_targets[i, :] bbox_targets_expand[i, start:end] = bbox_targets[i, :]
......
...@@ -279,8 +279,8 @@ def eval_map(det_results, ...@@ -279,8 +279,8 @@ def eval_map(det_results,
bbox[:, 3] - bbox[:, 1] + 1) bbox[:, 3] - bbox[:, 1] + 1)
for k, (min_area, max_area) in enumerate(area_ranges): for k, (min_area, max_area) in enumerate(area_ranges):
num_gts[k] += np.sum( num_gts[k] += np.sum(
np.logical_not(cls_gt_ignore[j]) & np.logical_not(cls_gt_ignore[j])
(gt_areas >= min_area) & (gt_areas < max_area)) & (gt_areas >= min_area) & (gt_areas < max_area))
# sort all det bboxes by score, also sort tp and fp # sort all det bboxes by score, also sort tp and fp
cls_dets = np.vstack(cls_dets) cls_dets = np.vstack(cls_dets)
num_dets = cls_dets.shape[0] num_dets = cls_dets.shape[0]
...@@ -312,11 +312,12 @@ def eval_map(det_results, ...@@ -312,11 +312,12 @@ def eval_map(det_results,
all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
all_num_gts = np.vstack( all_num_gts = np.vstack(
[cls_result['num_gts'] for cls_result in eval_results]) [cls_result['num_gts'] for cls_result in eval_results])
mean_ap = [ mean_ap = []
all_ap[all_num_gts[:, i] > 0, i].mean() for i in range(num_scales):
if np.any(all_num_gts[:, i] > 0) else 0.0 if np.any(all_num_gts[:, i] > 0):
for i in range(num_scales) mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
] else:
mean_ap.append(0.0)
else: else:
aps = [] aps = []
for cls_result in eval_results: for cls_result in eval_results:
...@@ -368,8 +369,8 @@ def print_map_summary(mean_ap, results, dataset=None): ...@@ -368,8 +369,8 @@ def print_map_summary(mean_ap, results, dataset=None):
for j in range(num_classes): for j in range(num_classes):
row_data = [ row_data = [
label_names[j], num_gts[i, j], results[j]['num_dets'], label_names[j], num_gts[i, j], results[j]['num_dets'],
'{:.3f}'.format(recalls[i, j]), '{:.3f}'.format( '{:.3f}'.format(recalls[i, j]),
precisions[i, j]), '{:.3f}'.format(aps[i, j]) '{:.3f}'.format(precisions[i, j]), '{:.3f}'.format(aps[i, j])
] ]
table_data.append(row_data) table_data.append(row_data)
table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])]) table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])])
......
...@@ -45,8 +45,9 @@ def multiclass_nms(multi_bboxes, ...@@ -45,8 +45,9 @@ def multiclass_nms(multi_bboxes,
_scores *= score_factors[cls_inds] _scores *= score_factors[cls_inds]
cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1) cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
cls_dets, _ = nms_op(cls_dets, **nms_cfg_) cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
cls_labels = multi_bboxes.new_full( cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),
(cls_dets.shape[0], ), i - 1, dtype=torch.long) i - 1,
dtype=torch.long)
bboxes.append(cls_dets) bboxes.append(cls_dets)
labels.append(cls_labels) labels.append(cls_labels)
if bboxes: if bboxes:
......
...@@ -15,6 +15,6 @@ __all__ = [ ...@@ -15,6 +15,6 @@ __all__ = [
'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler', 'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation', 'WIDERFaceDataset',
'WIDERFaceDataset', 'DATASETS', 'build_dataset' 'DATASETS', 'build_dataset'
] ]
...@@ -29,8 +29,8 @@ def _concat_dataset(cfg, default_args=None): ...@@ -29,8 +29,8 @@ def _concat_dataset(cfg, default_args=None):
def build_dataset(cfg, default_args=None): def build_dataset(cfg, default_args=None):
if cfg['type'] == 'RepeatDataset': if cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args), dataset = RepeatDataset(
cfg['times']) build_dataset(cfg['dataset'], default_args), cfg['times'])
elif isinstance(cfg['ann_file'], (list, tuple)): elif isinstance(cfg['ann_file'], (list, tuple)):
dataset = _concat_dataset(cfg, default_args) dataset = _concat_dataset(cfg, default_args)
else: else:
......
...@@ -5,5 +5,5 @@ from .registry import DATASETS ...@@ -5,5 +5,5 @@ from .registry import DATASETS
@DATASETS.register_module @DATASETS.register_module
class CityscapesDataset(CocoDataset): class CityscapesDataset(CocoDataset):
CLASSES = ('person', 'rider', 'car', 'truck', 'bus', CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'train', 'motorcycle', 'bicycle') 'bicycle')
...@@ -115,8 +115,8 @@ class RandomCrop(object): ...@@ -115,8 +115,8 @@ class RandomCrop(object):
left = random.uniform(w - new_w) left = random.uniform(w - new_w)
top = random.uniform(h - new_h) top = random.uniform(h - new_h)
patch = np.array((int(left), int(top), int(left + new_w), patch = np.array(
int(top + new_h))) (int(left), int(top), int(left + new_w), int(top + new_h)))
overlaps = bbox_overlaps( overlaps = bbox_overlaps(
patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
if overlaps.min() < min_iou: if overlaps.min() < min_iou:
......
...@@ -30,21 +30,23 @@ class GARetinaHead(GuidedAnchorHead): ...@@ -30,21 +30,23 @@ class GARetinaHead(GuidedAnchorHead):
for i in range(self.stacked_convs): for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append( self.cls_convs.append(
ConvModule(chn, ConvModule(
self.feat_channels, chn,
3, self.feat_channels,
stride=1, 3,
padding=1, stride=1,
conv_cfg=self.conv_cfg, padding=1,
norm_cfg=self.norm_cfg)) conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append( self.reg_convs.append(
ConvModule(chn, ConvModule(
self.feat_channels, chn,
3, self.feat_channels,
stride=1, 3,
padding=1, stride=1,
conv_cfg=self.conv_cfg, padding=1,
norm_cfg=self.norm_cfg)) conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1) self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2, self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
...@@ -59,15 +61,13 @@ class GARetinaHead(GuidedAnchorHead): ...@@ -59,15 +61,13 @@ class GARetinaHead(GuidedAnchorHead):
self.feat_channels, self.feat_channels,
kernel_size=3, kernel_size=3,
deformable_groups=self.deformable_groups) deformable_groups=self.deformable_groups)
self.retina_cls = MaskedConv2d(self.feat_channels, self.retina_cls = MaskedConv2d(
self.num_anchors * self.feat_channels,
self.cls_out_channels, self.num_anchors * self.cls_out_channels,
3, 3,
padding=1) padding=1)
self.retina_reg = MaskedConv2d(self.feat_channels, self.retina_reg = MaskedConv2d(
self.num_anchors * 4, self.feat_channels, self.num_anchors * 4, 3, padding=1)
3,
padding=1)
def init_weights(self): def init_weights(self):
for m in self.cls_convs: for m in self.cls_convs:
......
...@@ -17,10 +17,8 @@ class GARPNHead(GuidedAnchorHead): ...@@ -17,10 +17,8 @@ class GARPNHead(GuidedAnchorHead):
super(GARPNHead, self).__init__(2, in_channels, **kwargs) super(GARPNHead, self).__init__(2, in_channels, **kwargs)
def _init_layers(self): def _init_layers(self):
self.rpn_conv = nn.Conv2d(self.in_channels, self.rpn_conv = nn.Conv2d(
self.feat_channels, self.in_channels, self.feat_channels, 3, padding=1)
3,
padding=1)
super(GARPNHead, self)._init_layers() super(GARPNHead, self)._init_layers()
def init_weights(self): def init_weights(self):
...@@ -43,19 +41,21 @@ class GARPNHead(GuidedAnchorHead): ...@@ -43,19 +41,21 @@ class GARPNHead(GuidedAnchorHead):
img_metas, img_metas,
cfg, cfg,
gt_bboxes_ignore=None): gt_bboxes_ignore=None):
losses = super(GARPNHead, self).loss(cls_scores, losses = super(GARPNHead, self).loss(
bbox_preds, cls_scores,
shape_preds, bbox_preds,
loc_preds, shape_preds,
gt_bboxes, loc_preds,
None, gt_bboxes,
img_metas, None,
cfg, img_metas,
gt_bboxes_ignore=gt_bboxes_ignore) cfg,
return dict(loss_rpn_cls=losses['loss_cls'], gt_bboxes_ignore=gt_bboxes_ignore)
loss_rpn_bbox=losses['loss_bbox'], return dict(
loss_anchor_shape=losses['loss_shape'], loss_rpn_cls=losses['loss_cls'],
loss_anchor_loc=losses['loss_loc']) loss_rpn_bbox=losses['loss_bbox'],
loss_anchor_shape=losses['loss_shape'],
loss_anchor_loc=losses['loss_loc'])
def get_bboxes_single(self, def get_bboxes_single(self,
cls_scores, cls_scores,
......
...@@ -282,13 +282,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -282,13 +282,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
mask_roi_extractor = self.mask_roi_extractor[i] mask_roi_extractor = self.mask_roi_extractor[i]
mask_head = self.mask_head[i] mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [ mask_classes = mask_head.num_classes - 1
[] for _ in range(mask_head.num_classes - 1) segm_result = [[] for _ in range(mask_classes)]
]
else: else:
_bboxes = ( _bboxes = (
det_bboxes[:, :4] * scale_factor det_bboxes[:, :4] *
if rescale else det_bboxes) scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes]) mask_rois = bbox2roi([_bboxes])
mask_feats = mask_roi_extractor( mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)], x[:len(mask_roi_extractor.featmap_strides)],
...@@ -321,13 +320,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -321,13 +320,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
if self.with_mask: if self.with_mask:
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [ mask_classes = self.mask_head[-1].num_classes - 1
[] for _ in range(self.mask_head[-1].num_classes - 1) segm_result = [[] for _ in range(mask_classes)]
]
else: else:
_bboxes = ( _bboxes = (
det_bboxes[:, :4] * scale_factor det_bboxes[:, :4] *
if rescale else det_bboxes) scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes]) mask_rois = bbox2roi([_bboxes])
aug_masks = [] aug_masks = []
for i in range(self.num_stages): for i in range(self.num_stages):
......
...@@ -205,9 +205,10 @@ class HybridTaskCascade(CascadeRCNN): ...@@ -205,9 +205,10 @@ class HybridTaskCascade(CascadeRCNN):
gt_bboxes_ignore = [None for _ in range(num_imgs)] gt_bboxes_ignore = [None for _ in range(num_imgs)]
for j in range(num_imgs): for j in range(num_imgs):
assign_result = bbox_assigner.assign( assign_result = bbox_assigner.assign(proposal_list[j],
proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], gt_bboxes[j],
gt_labels[j]) gt_bboxes_ignore[j],
gt_labels[j])
sampling_result = bbox_sampler.sample( sampling_result = bbox_sampler.sample(
assign_result, assign_result,
proposal_list[j], proposal_list[j],
...@@ -308,13 +309,12 @@ class HybridTaskCascade(CascadeRCNN): ...@@ -308,13 +309,12 @@ class HybridTaskCascade(CascadeRCNN):
if self.with_mask: if self.with_mask:
mask_head = self.mask_head[i] mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [ mask_classes = mask_head.num_classes - 1
[] for _ in range(mask_head.num_classes - 1) segm_result = [[] for _ in range(mask_classes)]
]
else: else:
_bboxes = ( _bboxes = (
det_bboxes[:, :4] * scale_factor det_bboxes[:, :4] *
if rescale else det_bboxes) scale_factor if rescale else det_bboxes)
mask_pred = self._mask_forward_test( mask_pred = self._mask_forward_test(
i, x, _bboxes, semantic_feat=semantic_feat) i, x, _bboxes, semantic_feat=semantic_feat)
segm_result = mask_head.get_seg_masks( segm_result = mask_head.get_seg_masks(
...@@ -342,13 +342,12 @@ class HybridTaskCascade(CascadeRCNN): ...@@ -342,13 +342,12 @@ class HybridTaskCascade(CascadeRCNN):
if self.with_mask: if self.with_mask:
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [ mask_classes = self.mask_head[-1].num_classes - 1
[] for _ in range(self.mask_head[-1].num_classes - 1) segm_result = [[] for _ in range(mask_classes)]
]
else: else:
_bboxes = ( _bboxes = (
det_bboxes[:, :4] * scale_factor det_bboxes[:, :4] *
if rescale else det_bboxes) scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes]) mask_rois = bbox2roi([_bboxes])
aug_masks = [] aug_masks = []
......
...@@ -91,9 +91,10 @@ class BBoxTestMixin(object): ...@@ -91,9 +91,10 @@ class BBoxTestMixin(object):
# after merging, bboxes will be rescaled to the original image size # after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes( merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms( det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
merged_bboxes, merged_scores, rcnn_test_cfg.score_thr, rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img) rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels return det_bboxes, det_labels
...@@ -121,9 +122,11 @@ class MaskTestMixin(object): ...@@ -121,9 +122,11 @@ class MaskTestMixin(object):
if self.with_shared_head: if self.with_shared_head:
mask_feats = self.shared_head(mask_feats) mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks( segm_result = self.mask_head.get_seg_masks(mask_pred, _bboxes,
mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape, det_labels,
scale_factor, rescale) self.test_cfg.rcnn,
ori_shape, scale_factor,
rescale)
return segm_result return segm_result
def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels): def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
......
...@@ -125,9 +125,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -125,9 +125,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_bboxes_ignore = [None for _ in range(num_imgs)] gt_bboxes_ignore = [None for _ in range(num_imgs)]
sampling_results = [] sampling_results = []
for i in range(num_imgs): for i in range(num_imgs):
assign_result = bbox_assigner.assign( assign_result = bbox_assigner.assign(proposal_list[i],
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], gt_bboxes[i],
gt_labels[i]) gt_bboxes_ignore[i],
gt_labels[i])
sampling_result = bbox_sampler.sample( sampling_result = bbox_sampler.sample(
assign_result, assign_result,
proposal_list[i], proposal_list[i],
...@@ -146,8 +147,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -146,8 +147,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
bbox_feats = self.shared_head(bbox_feats) bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head(bbox_feats) cls_score, bbox_pred = self.bbox_head(bbox_feats)
bbox_targets = self.bbox_head.get_target( bbox_targets = self.bbox_head.get_target(sampling_results,
sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn) gt_bboxes, gt_labels,
self.train_cfg.rcnn)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
*bbox_targets) *bbox_targets)
losses.update(loss_bbox) losses.update(loss_bbox)
...@@ -179,8 +181,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -179,8 +181,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
mask_feats = bbox_feats[pos_inds] mask_feats = bbox_feats[pos_inds]
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
mask_targets = self.mask_head.get_target( mask_targets = self.mask_head.get_target(sampling_results,
sampling_results, gt_masks, self.train_cfg.rcnn) gt_masks,
self.train_cfg.rcnn)
pos_labels = torch.cat( pos_labels = torch.cat(
[res.pos_gt_labels for res in sampling_results]) [res.pos_gt_labels for res in sampling_results])
loss_mask = self.mask_head.loss(mask_pred, mask_targets, loss_mask = self.mask_head.loss(mask_pred, mask_targets,
......
...@@ -98,7 +98,7 @@ class FusedSemanticHead(nn.Module): ...@@ -98,7 +98,7 @@ class FusedSemanticHead(nn.Module):
x = self.conv_embedding(x) x = self.conv_embedding(x)
return mask_pred, x return mask_pred, x
@force_fp32(apply_to=('mask_pred',)) @force_fp32(apply_to=('mask_pred', ))
def loss(self, mask_pred, labels): def loss(self, mask_pred, labels):
labels = labels.squeeze(1).long() labels = labels.squeeze(1).long()
loss_semantic_seg = self.criterion(mask_pred, labels) loss_semantic_seg = self.criterion(mask_pred, labels)
......
...@@ -72,7 +72,7 @@ class SingleRoIExtractor(nn.Module): ...@@ -72,7 +72,7 @@ class SingleRoIExtractor(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls return target_lvls
@force_fp32(apply_to=('feats',), out_fp16=True) @force_fp32(apply_to=('feats', ), out_fp16=True)
def forward(self, feats, rois): def forward(self, feats, rois):
if len(feats) == 1: if len(feats) == 1:
return self.roi_layers[0](feats[0], rois) return self.roi_layers[0](feats[0], rois)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment