Unverified commit b3b74481, authored by Abhijit Deo and committed by GitHub

Vectorize box decoding in FCOS (#6203)



* basic structure

* added constraints

* fixed errors

* thanks to Vadim!

* addressed the comments and added docstring

* Apply suggestions from code review
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 329b9789
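For context on what the decoder in the first hunk below computes (mirroring the existing `decode_single` logic): each reference box is collapsed to its center, the four encoded (l, t, r, b) offsets are subtracted from or added to that center to recover (x1, y1, x2, y2), and with `normalize_by_size=True` the offsets are first rescaled by the box width and height. A small worked sketch with made-up numbers, not part of the commit:

import torch

# Made-up reference box (x1, y1, x2, y2) and its encoded (l, t, r, b) offsets.
box = torch.tensor([[2.0, 2.0, 6.0, 10.0]])   # center (4, 6), width 4, height 8
rel = torch.tensor([[0.5, 0.25, 0.5, 0.25]])  # offsets expressed relative to box size

ctr_x = 0.5 * (box[..., 0] + box[..., 2])     # 4.0
ctr_y = 0.5 * (box[..., 1] + box[..., 3])     # 6.0
w = box[..., 2] - box[..., 0]                 # 4.0
h = box[..., 3] - box[..., 1]                 # 8.0

# normalize_by_size=True: undo the size normalization before decoding
rel = rel * torch.stack((w, h, w, h), dim=-1)  # [2.0, 2.0, 2.0, 2.0]
decoded = torch.stack(
    (ctr_x - rel[..., 0], ctr_y - rel[..., 1], ctr_x + rel[..., 2], ctr_y + rel[..., 3]),
    dim=-1,
)
print(decoded)  # tensor([[2., 4., 6., 8.]])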
@@ -300,6 +300,42 @@ class BoxLinearCoder:
         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
         return pred_boxes
 
+    def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        """
+        Vectorized version of the `decode_single` method.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (List[Tensor]): List of reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have the same size for the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
+        """
+        boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
 
 class Matcher:
     """
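A minimal usage sketch (not part of the commit) of the new method against the existing per-image path, assuming `BoxLinearCoder` is importable from `torchvision.models.detection._utils` and using random stand-ins for real anchors and regression outputs: `decode_all` on the stacked inputs should reproduce what a `decode_single` loop produces.

import torch
from torchvision.models.detection._utils import BoxLinearCoder  # assumed import path

def random_boxes(n):
    # well-formed (x1, y1, x2, y2) boxes with positive width/height
    xy1 = torch.rand(n, 2) * 10
    wh = torch.rand(n, 2) * 5 + 1
    return torch.cat((xy1, xy1 + wh), dim=-1)

coder = BoxLinearCoder(normalize_by_size=True)
num_images, num_anchors = 2, 8

anchors = [random_boxes(num_anchors) for _ in range(num_images)]  # per-image reference boxes
rel_codes = torch.rand(num_images, num_anchors, 4)                # encoded (l, t, r, b) offsets

vectorized = coder.decode_all(rel_codes, anchors)                 # (num_images, num_anchors, 4)
looped = torch.stack(
    [coder.decode_single(r, a) for r, a in zip(rel_codes, anchors)]
)
torch.testing.assert_close(vectorized, looped)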
@@ -87,14 +87,12 @@ class FCOSHead(nn.Module):
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
 
         # regression loss: GIoU loss
-        # TODO: vectorize this instead of using a for loop
-        pred_boxes = [
-            self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
-            for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
-        ]
+        pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
 
         # amp issue: pred_boxes need to convert float
         loss_bbox_reg = generalized_box_iou_loss(
-            torch.stack(pred_boxes)[foregroud_mask].float(),
+            pred_boxes[foregroud_mask],
             torch.stack(all_gt_boxes_targets)[foregroud_mask],
             reduction="sum",
         )
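To make the shapes in the hunk above concrete, here is a sketch with made-up sizes (the random data and the helper are illustrative, not code from the commit): `bbox_regression` is (num_images, num_anchors, 4), `anchors` is a per-image list, `decode_all` returns (num_images, num_anchors, 4), and boolean indexing with the foreground mask flattens that directly to the (num_foreground, 4) tensor `generalized_box_iou_loss` expects, with no intermediate Python list or `torch.stack` over predictions.

import torch
from torchvision.models.detection._utils import BoxLinearCoder  # assumed import path
from torchvision.ops import generalized_box_iou_loss

def random_boxes(n):
    # well-formed (x1, y1, x2, y2) boxes with positive width/height
    xy1 = torch.rand(n, 2) * 10
    wh = torch.rand(n, 2) * 5 + 1
    return torch.cat((xy1, xy1 + wh), dim=-1)

coder = BoxLinearCoder(normalize_by_size=True)
num_images, num_anchors = 2, 8

bbox_regression = torch.rand(num_images, num_anchors, 4)        # head outputs: (l, t, r, b) offsets
anchors = [random_boxes(num_anchors) for _ in range(num_images)]
all_gt_boxes_targets = torch.stack([random_boxes(num_anchors) for _ in range(num_images)])
foregroud_mask = torch.rand(num_images, num_anchors) > 0.5      # anchors assigned to a GT box (name mirrors the diff)

pred_boxes = coder.decode_all(bbox_regression, anchors)         # (num_images, num_anchors, 4)
loss_bbox_reg = generalized_box_iou_loss(
    pred_boxes[foregroud_mask],                                 # (num_foreground, 4)
    all_gt_boxes_targets[foregroud_mask],
    reduction="sum",
)
print(pred_boxes.shape, loss_bbox_reg.item())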