Commit 09c3bc4c authored by yhcao6's avatar yhcao6
Browse files

rename

parent 4fcdf6e9
......@@ -21,8 +21,12 @@ class OHEMSampler(BaseSampler):
self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head
def _sample_pos(self, assign_result, num_expected, bboxes=None,
feats=None, **kwargs):
def _sample_pos(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
"""Hard sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -42,12 +46,16 @@ class OHEMSampler(BaseSampler):
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction=False)['loss_cls']
reduce=False)['loss_cls']
_, topk_loss_pos_inds = loss_pos.topk(num_expected)
return pos_inds[topk_loss_pos_inds]
def _sample_neg(self, assign_result, num_expected, bboxes=None,
feats=None, **kwargs):
def _sample_neg(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
"""Hard sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......@@ -67,6 +75,6 @@ class OHEMSampler(BaseSampler):
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction=False)['loss_cls']
reduce=False)['loss_cls']
_, topk_loss_neg_inds = loss_neg.topk(num_expected)
return neg_inds[topk_loss_neg_inds]
......@@ -11,11 +11,11 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduction=True):
reduce=True):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
if reduction:
if reduce:
return torch.sum(raw * weight)[None] / avg_factor
else:
return raw * weight / avg_factor
......
......@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights, reduction=True):
bbox_weights, reduce=True):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
cls_score, labels, label_weights, reduction=reduction)
cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
losses['loss_reg'] = weighted_smoothl1(
......
......@@ -119,7 +119,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
feats=[xx[i][None] for xx in x])
feats=[lvl_feat[i][None] for lvl_feat in x])
assign_results.append(assign_result)
sampling_results.append(sampling_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment