Commit e84af8ce authored by Kai Chen's avatar Kai Chen
Browse files

minor fix

parent 8aaadf2c
......@@ -121,8 +121,8 @@ def anchor_target_single(flat_anchors,
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = gt_labels.new_zeros(num_valid_anchors)
label_weights = gt_labels.new_zeros(num_valid_anchors, dtype=torch.float)
labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
......
......@@ -258,6 +258,11 @@ class RetinaHead(nn.Module):
bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
proposals = delta2bbox(anchors, bbox_pred, self.target_means,
self.target_stds, img_shape)
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
maxscores, _ = scores.max(dim=1)
_, topk_inds = maxscores.topk(cfg.nms_pre)
proposals = proposals[topk_inds, :]
scores = scores[topk_inds, :]
mlvl_proposals.append(proposals)
mlvl_scores.append(scores)
mlvl_proposals = torch.cat(mlvl_proposals)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment