Unverified Commit 777d5d4b authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #492 from yhcao6/no-expand

remove expand loop in bbox head to speed up
parents bc4f7533 bd8fd27e
...@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes, ...@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes,
bbox_weights[:num_pos, :] = 1 bbox_weights[:num_pos, :] = 1
if num_neg > 0: if num_neg > 0:
label_weights[-num_neg:] = 1.0 label_weights[-num_neg:] = 1.0
if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
......
...@@ -94,10 +94,16 @@ class BBoxHead(nn.Module): ...@@ -94,10 +94,16 @@ class BBoxHead(nn.Module):
cls_score, labels, label_weights, reduce=reduce) cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels) losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None: if bbox_pred is not None:
pos_inds = labels > 0
if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_inds, labels[pos_inds]]
losses['loss_reg'] = weighted_smoothl1( losses['loss_reg'] = weighted_smoothl1(
bbox_pred, pos_bbox_pred,
bbox_targets, bbox_targets[pos_inds],
bbox_weights, bbox_weights[pos_inds],
avg_factor=bbox_targets.size(0)) avg_factor=bbox_targets.size(0))
return losses return losses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment