Commit 09c3bc4c authored by yhcao6

rename

parent 4fcdf6e9
@@ -21,8 +21,12 @@ class OHEMSampler(BaseSampler):
         self.bbox_roi_extractor = bbox_roi_extractor
         self.bbox_head = bbox_head
 
-    def _sample_pos(self, assign_result, num_expected, bboxes=None,
-                    feats=None, **kwargs):
+    def _sample_pos(self,
+                    assign_result,
+                    num_expected,
+                    bboxes=None,
+                    feats=None,
+                    **kwargs):
         """Hard sample some positive samples."""
         pos_inds = torch.nonzero(assign_result.gt_inds > 0)
         if pos_inds.numel() != 0:
@@ -42,12 +46,16 @@ class OHEMSampler(BaseSampler):
                 label_weights=cls_score.new_ones(cls_score.size(0)),
                 bbox_targets=None,
                 bbox_weights=None,
-                reduction=False)['loss_cls']
+                reduce=False)['loss_cls']
             _, topk_loss_pos_inds = loss_pos.topk(num_expected)
             return pos_inds[topk_loss_pos_inds]
 
-    def _sample_neg(self, assign_result, num_expected, bboxes=None,
-                    feats=None, **kwargs):
+    def _sample_neg(self,
+                    assign_result,
+                    num_expected,
+                    bboxes=None,
+                    feats=None,
+                    **kwargs):
         """Hard sample some negative samples."""
         neg_inds = torch.nonzero(assign_result.gt_inds == 0)
         if neg_inds.numel() != 0:
@@ -67,6 +75,6 @@ class OHEMSampler(BaseSampler):
                 label_weights=cls_score.new_ones(cls_score.size(0)),
                 bbox_targets=None,
                 bbox_weights=None,
-                reduction=False)['loss_cls']
+                reduce=False)['loss_cls']
             _, topk_loss_neg_inds = loss_neg.topk(num_expected)
             return neg_inds[topk_loss_neg_inds]
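The reason the sampler needs an un-reduced loss is visible in the two hunks above: OHEM ranks each candidate by its own classification loss and keeps the top-k hardest ones. Below is a minimal, self-contained sketch of that selection step; it uses a plain F.cross_entropy stand-in instead of the real self.bbox_head.loss call, and the helper name hard_mining_topk together with the tensor shapes are illustrative assumptions, not code from this commit.

```python
import torch
import torch.nn.functional as F


def hard_mining_topk(cls_score, labels, candidate_inds, num_expected):
    """Keep the num_expected candidates with the largest per-sample loss.

    cls_score:      (N, num_classes) logits for the candidate boxes.
    labels:         (N,) target class index per candidate.
    candidate_inds: (N,) indices of the candidates in the original proposal set.
    """
    # reduction='none' plays the same role as reduce=False above:
    # one loss value per candidate instead of a single reduced scalar.
    per_sample_loss = F.cross_entropy(cls_score, labels, reduction='none')
    k = min(num_expected, per_sample_loss.numel())
    _, topk_inds = per_sample_loss.topk(k)
    return candidate_inds[topk_inds]


# Toy usage: 6 candidates, keep the 2 hardest.
scores = torch.randn(6, 4)
labels = torch.tensor([0, 1, 2, 3, 0, 1])
inds = torch.arange(6)
print(hard_mining_topk(scores, labels, inds, num_expected=2))
```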
@@ -11,11 +11,11 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
 
 
 def weighted_cross_entropy(pred, label, weight, avg_factor=None,
-                           reduction=True):
+                           reduce=True):
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     raw = F.cross_entropy(pred, label, reduction='none')
-    if reduction:
+    if reduce:
         return torch.sum(raw * weight)[None] / avg_factor
     else:
         return raw * weight / avg_factor
...
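The renamed flag controls whether weighted_cross_entropy returns a single averaged value or one loss per sample (the latter is what the OHEM sampler consumes). A small self-contained check of that behaviour, reusing the function body from the hunk above with made-up inputs:

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
    # Same logic as the function in the hunk above.
    if avg_factor is None:
        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
    raw = F.cross_entropy(pred, label, reduction='none')
    if reduce:
        return torch.sum(raw * weight)[None] / avg_factor
    else:
        return raw * weight / avg_factor


pred = torch.randn(5, 3)                     # 5 samples, 3 classes
label = torch.tensor([0, 1, 2, 0, 1])
weight = torch.tensor([1., 1., 0., 1., 1.])  # sample 2 is ignored

print(weighted_cross_entropy(pred, label, weight).shape)                # torch.Size([1])
print(weighted_cross_entropy(pred, label, weight, reduce=False).shape)  # torch.Size([5])
```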
@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
         return cls_reg_targets
 
     def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
-             bbox_weights, reduction=True):
+             bbox_weights, reduce=True):
         losses = dict()
         if cls_score is not None:
             losses['loss_cls'] = weighted_cross_entropy(
-                cls_score, labels, label_weights, reduction=reduction)
+                cls_score, labels, label_weights, reduce=reduce)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             losses['loss_reg'] = weighted_smoothl1(
...
@@ -119,7 +119,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                     proposal_list[i],
                     gt_bboxes[i],
                     gt_labels[i],
-                    feats=[xx[i][None] for xx in x])
+                    feats=[lvl_feat[i][None] for lvl_feat in x])
                 assign_results.append(assign_result)
                 sampling_results.append(sampling_result)
...
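The renamed comprehension takes the multi-level features x (one tensor per FPN level, shaped (N, C, H, W)) and slices out image i at every level while keeping a batch dimension of 1, so the per-image sampler can feed them straight to the bbox head. A toy illustration of the shape transformation; the level sizes and channel count below are assumptions, not values from this commit:

```python
import torch

# x mimics multi-level FPN features for a batch of 2 images: each level is (N, C, H, W).
x = [torch.randn(2, 256, s, s) for s in (64, 32, 16)]

i = 0  # pick the first image of the batch
feats = [lvl_feat[i][None] for lvl_feat in x]  # [None] restores a batch dim of 1

print([f.shape for f in feats])
# [torch.Size([1, 256, 64, 64]), torch.Size([1, 256, 32, 32]), torch.Size([1, 256, 16, 16])]
```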