Unverified Commit 0467c9d7 authored by Vasilis Vryniotis, committed by GitHub

Vectorize RetinaNet's postprocessing (#2828)



* Vectorize operations across all feature levels.

* Remove unnecessary other_outputs variable.

* Split per feature level.

* Perform batched_nms across feature levels (illustrated in the sketch after this commit header).

* Add extra parameter for limiting detections before and after nms.

* Restoring default threshold.

* Apply suggestions from code review
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Renaming variable.
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent d94b45da
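
The change these bullets pivot on is replacing the per-class box_ops.nms loop with a single box_ops.batched_nms call over candidates gathered from all feature levels. As a minimal illustrative sketch (not part of the commit) of what batched_nms does: it offsets boxes by their label so that boxes of different classes never suppress each other.

    import torch
    from torchvision.ops import batched_nms

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],    # overlaps box 0 (IoU ~0.68)
                          [0., 0., 10., 10.]])   # identical box, but a different class
    scores = torch.tensor([0.9, 0.8, 0.7])
    labels = torch.tensor([0, 0, 1])

    keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
    print(keep)  # tensor([0, 2]): box 1 is suppressed by box 0; box 2 survives as another class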
@@ -22,5 +22,6 @@ htmlcov
 gen.yml
 .mypy_cache
 .vscode/
+.idea/
 *.orig
 *-checkpoint.ipynb
\ No newline at end of file
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -291,6 +291,7 @@ class RetinaNet(nn.Module):
             considered as positive during training.
         bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
             considered as negative during training.
+        topk_candidates (int): Number of best detections to keep before NMS.

     Example:
@@ -339,7 +340,8 @@ class RetinaNet(nn.Module):
                  score_thresh=0.05,
                  nms_thresh=0.5,
                  detections_per_img=300,
-                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+                 fg_iou_thresh=0.5, bg_iou_thresh=0.4,
+                 topk_candidates=1000):
         super().__init__()

         if not hasattr(backbone, "out_channels"):
@@ -382,6 +384,7 @@ class RetinaNet(nn.Module):
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates

         # used only on torchscript mode
         self._has_warned = False
@@ -408,77 +411,63 @@ class RetinaNet(nn.Module):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
-
-        class_logits = head_outputs.pop('cls_logits')
-        box_regression = head_outputs.pop('bbox_regression')
-        other_outputs = head_outputs
-
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
-
-        # create labels for each score
-        labels = torch.arange(num_classes, device=device)
-        labels = labels.view(1, -1).expand_as(scores)
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+        class_logits = head_outputs['cls_logits']
+        box_regression = head_outputs['bbox_regression']
+
+        num_images = len(image_shapes)

         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

-        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
-
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
-
-            other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]

             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})

-            for class_index in range(num_classes):
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
+
                 # remove low scoring boxes
-                inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
-                other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
-
-                # remove empty boxes
-                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
-
-                # non-maximum suppression, independently done per class
-                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]

                 # keep only topk scoring predictions
-                keep = keep[:self.detections_per_img]
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
+
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
+
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

-                image_boxes.append(boxes_per_class)
-                image_scores.append(scores_per_class)
-                image_labels.append(labels_per_class)
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)

-                for k, v in other_outputs_per_class:
-                    if k not in image_other_outputs:
-                        image_other_outputs[k] = []
-                    image_other_outputs[k].append(v)
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[:self.detections_per_img]

             detections.append({
-                'boxes': torch.cat(image_boxes, dim=0),
-                'scores': torch.cat(image_scores, dim=0),
-                'labels': torch.cat(image_labels, dim=0),
+                'boxes': image_boxes[keep],
+                'scores': image_scores[keep],
+                'labels': image_labels[keep],
             })

-            for k, v in image_other_outputs.items():
-                detections[-1].update({k: torch.cat(v, dim=0)})
-
         return detections
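
The per-level block above hinges on one indexing trick: logits of shape (num_anchors, num_classes) are flattened before thresholding and top-k, so every surviving flat index encodes both an anchor and a class, recovered with integer division and modulo. A toy sketch with made-up shapes:

    import torch

    num_anchors, num_classes = 4, 3
    logits = torch.randn(num_anchors, num_classes)

    scores = torch.sigmoid(logits).flatten()   # shape: (num_anchors * num_classes,)
    topk_scores, topk_idxs = scores.topk(5)

    anchor_idxs = topk_idxs // num_classes     # row (anchor) each flat index came from
    labels = topk_idxs % num_classes           # column (class) each flat index came from

    # the decoded (anchor, class) pairs point back at the same scores
    assert torch.equal(topk_scores, torch.sigmoid(logits)[anchor_idxs, labels])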
     def forward(self, images, targets=None):
@@ -557,8 +546,23 @@ class RetinaNet(nn.Module):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in num_anchors_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

         if torch.jit.is_scripting():
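
The forward-path addition recovers how many predictions belong to each feature level: a feature map contributes H * W locations, the anchors-per-location count A falls out as HWA // HW, and Tensor.split then slices the concatenated head outputs back into per-level chunks. A toy sketch with made-up shapes:

    import torch

    # pretend pyramid with three levels and A = 9 anchors per location
    feature_shapes = [(8, 8), (4, 4), (2, 2)]
    N, A, num_classes = 1, 9, 5

    hw_per_level = [h * w for h, w in feature_shapes]   # locations per level: [64, 16, 4]
    HW = sum(hw_per_level)                              # 84
    cls_logits = torch.randn(N, HW * A, num_classes)    # concatenated head output

    HWA = cls_logits.size(1)                            # 756
    num_anchors_per_level = [hw * (HWA // HW) for hw in hw_per_level]   # [576, 144, 36]

    # split the concatenated predictions back into per-level chunks
    per_level = cls_logits.split(num_anchors_per_level, dim=1)
    print([tuple(t.shape) for t in per_level])          # [(1, 576, 5), (1, 144, 5), (1, 36, 5)]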