"vscode:/vscode.git/clone" did not exist on "9c6f46473acec0a1150e32e4feaa7ac55df789c9"
Commit ee7e679a authored by Mordekaiser's avatar Mordekaiser Committed by Kai Chen
Browse files

fix OHEM with cascade rcnn (#373)

* fix OHEM with cascade_rcnn

* fix OHEM with cascade_rcnn

* delete space

* delete white space

* delete unused lib

* Delete cascade_rcnn_ohem_r101_fpn_1x.py

* fix unreasonable code

* fix Single quote

* fix code style

* fix code style

* fix file permission
parent 2800bf3a
...@@ -15,8 +15,13 @@ class OHEMSampler(BaseSampler): ...@@ -15,8 +15,13 @@ class OHEMSampler(BaseSampler):
**kwargs): **kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub, super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals) add_gt_as_proposals)
self.bbox_roi_extractor = context.bbox_roi_extractor if not hasattr(context, 'num_stages'):
self.bbox_head = context.bbox_head self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head
else:
self.bbox_roi_extractor = context.bbox_roi_extractor[
context.current_stage]
self.bbox_head = context.bbox_head[context.current_stage]
def hard_mining(self, inds, num_expected, bboxes, labels, feats): def hard_mining(self, inds, num_expected, bboxes, labels, feats):
with torch.no_grad(): with torch.no_grad():
......
...@@ -7,7 +7,7 @@ from .base import BaseDetector ...@@ -7,7 +7,7 @@ from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, from mmdet.core import (build_assigner, bbox2roi, bbox2result, build_sampler,
merge_aug_masks) merge_aug_masks)
...@@ -131,17 +131,31 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -131,17 +131,31 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
proposal_list = proposals proposal_list = proposals
for i in range(self.num_stages): for i in range(self.num_stages):
self.current_stage = i
rcnn_train_cfg = self.train_cfg.rcnn[i] rcnn_train_cfg = self.train_cfg.rcnn[i]
lw = self.train_cfg.stage_loss_weights[i] lw = self.train_cfg.stage_loss_weights[i]
# assign gts and sample proposals # assign gts and sample proposals
assign_results, sampling_results = multi_apply( sampling_results = []
assign_and_sample, if self.with_bbox or self.with_mask:
proposal_list, bbox_assigner = build_assigner(rcnn_train_cfg.assigner)
gt_bboxes, bbox_sampler = build_sampler(
gt_bboxes_ignore, rcnn_train_cfg.sampler, context=self)
gt_labels, num_imgs = img.size(0)
cfg=rcnn_train_cfg) if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(num_imgs)]
for j in range(num_imgs):
assign_result = bbox_assigner.assign(
proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
gt_labels[j])
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[j],
gt_bboxes[j],
gt_labels[j],
feats=[lvl_feat[j][None] for lvl_feat in x])
sampling_results.append(sampling_result)
# bbox head forward and loss # bbox head forward and loss
bbox_roi_extractor = self.bbox_roi_extractor[i] bbox_roi_extractor = self.bbox_roi_extractor[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment