Commit 4fcdf6e9 authored by yhcao6

rename reduction of loss

parent 092b97f6
@@ -42,7 +42,7 @@ class OHEMSampler(BaseSampler):
             label_weights=cls_score.new_ones(cls_score.size(0)),
             bbox_targets=None,
             bbox_weights=None,
-            reduction='none')['loss_cls']
+            reduction=False)['loss_cls']
         _, topk_loss_pos_inds = loss_pos.topk(num_expected)
         return pos_inds[topk_loss_pos_inds]
@@ -67,6 +67,6 @@ class OHEMSampler(BaseSampler):
             label_weights=cls_score.new_ones(cls_score.size(0)),
             bbox_targets=None,
             bbox_weights=None,
-            reduction='none')['loss_cls']
+            reduction=False)['loss_cls']
         _, topk_loss_neg_inds = loss_neg.topk(num_expected)
         return neg_inds[topk_loss_neg_inds]
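
Both hunks above follow the same hard-mining pattern: the sampler asks the bbox head for the unreduced classification loss (one value per candidate) and keeps the samples with the largest loss. A minimal sketch of that pattern, assuming a free-standing helper rather than the sampler's actual method:

import torch

def hard_mining(inds, num_expected, per_sample_loss):
    # per_sample_loss must be unreduced, i.e. one entry per candidate in
    # `inds`, which is why the sampler requests the loss without reduction.
    with torch.no_grad():
        _, topk_inds = per_sample_loss.topk(num_expected)
    return inds[topk_inds]
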
@@ -11,13 +11,13 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
 def weighted_cross_entropy(pred, label, weight, avg_factor=None,
-                           reduction='elementwise_sum'):
+                           reduction=True):
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     raw = F.cross_entropy(pred, label, reduction='none')
-    if reduction == 'elementwise_sum':
+    if reduction:
         return torch.sum(raw * weight)[None] / avg_factor
-    elif reduction == 'none':
+    else:
         return raw * weight / avg_factor
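
For reference, the boolean-flag form of the function in this hunk as a self-contained sketch; the body is taken from the fragment above and only the imports are added:

import torch
import torch.nn.functional as F

def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduction=True):
    # Default normaliser: the number of samples with a positive weight.
    if avg_factor is None:
        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
    # Per-sample cross entropy; weighting and reduction are applied by hand.
    raw = F.cross_entropy(pred, label, reduction='none')
    if reduction:
        # Single-element tensor: weighted sum divided by avg_factor.
        return torch.sum(raw * weight)[None] / avg_factor
    else:
        # Per-sample weighted losses, e.g. for OHEM ranking.
        return raw * weight / avg_factor

With reduction=True the result has shape (1,); with reduction=False it keeps one entry per sample, which is what the OHEM sampler ranks. The string form behaves identically with 'elementwise_sum' and 'none' in place of True and False.
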
@@ -79,7 +79,7 @@ class BBoxHead(nn.Module):
         return cls_reg_targets

     def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
-             bbox_weights, reduction='elementwise_sum'):
+             bbox_weights, reduction=True):
         losses = dict()
         if cls_score is not None:
             losses['loss_cls'] = weighted_cross_entropy(
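
The call above is cut off by the collapsed diff. As an illustration only, a hypothetical, stripped-down version of the method (names other than weighted_cross_entropy and reduction are assumptions) would simply forward the flag, which is how the OHEM sampler obtains per-sample losses through BBoxHead.loss:

    # Hypothetical, simplified loss() that only handles the classification term.
    def loss(self, cls_score, labels, label_weights, reduction=True):
        losses = dict()
        if cls_score is not None:
            # Forward the flag so callers (e.g. OHEMSampler) can request
            # unreduced, per-sample losses.
            losses['loss_cls'] = weighted_cross_entropy(
                cls_score, labels, label_weights, reduction=reduction)
        return losses
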