Commit cb68807f authored by Cao Yuhang's avatar Cao Yuhang
Browse files

remove expand loop in bbox head to speed up

parent d5a6b5dc
......@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes,
bbox_weights[:num_pos, :] = 1
if num_neg > 0:
label_weights[-num_neg:] = 1.0
if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
......@@ -94,10 +94,16 @@ class BBoxHead(nn.Module):
cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
pos_mask = labels > 0
if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask]
else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_mask, labels[pos_mask]]
losses['loss_reg'] = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
pos_bbox_pred,
bbox_targets[pos_mask],
bbox_weights[pos_mask],
avg_factor=bbox_targets.size(0))
return losses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment