Unverified commit ba0d665b authored by Abhijit Deo, committed by GitHub

Vectorize box encoding in FCOS (#6278)



* initial structure

* fixed types of a few variables

* remove the commented code

* list -> List

* encode method takes tensors as input instead of a list of tensors

Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent 9b84859e
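For context, this change replaces the per-image `encode_single` loop in `FCOSHead.compute_loss` with a single batched `encode_all` call. Below is a minimal sketch of the equivalence being exploited; the tensor shapes and random boxes are illustrative assumptions, and `BoxLinearCoder` is imported from the private `torchvision.models.detection._utils` module as it exists at this commit:

```python
import torch
from torchvision.models.detection._utils import BoxLinearCoder

coder = BoxLinearCoder(normalize_by_size=True)

# Illustrative shapes (assumed for this sketch): N images, A anchors per image, xyxy boxes.
N, A = 2, 5
xy1 = torch.rand(N, A, 2) * 10
wh = torch.rand(N, A, 2) + 1.0
anchors = torch.cat((xy1, xy1 + wh), dim=-1)                # reference boxes
gt_boxes = torch.cat((xy1 - 1.0, xy1 + wh + 1.0), dim=-1)   # matched target boxes

# Old path: one encode_single call per image, then stack.
per_image = torch.stack(
    [coder.encode_single(a, g) for a, g in zip(anchors, gt_boxes)], dim=0
)

# New path: one batched call over the leading image dimension.
batched = coder.encode_all(anchors, gt_boxes)

torch.testing.assert_close(per_image, batched)
```

Both paths produce the same targets; the batched call simply removes the Python-level per-image loop from the loss computation.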
@@ -270,6 +270,40 @@ class BoxLinearCoder:
         return targets

+    def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Vectorized version of `encode_single`.
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+                decode the boxes.
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
     def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         """
         From a set of original boxes and encoded relative box offsets,
......
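As a worked check of the encoding above (the numbers are made up for illustration): an anchor box `[4, 4, 6, 6]` has center `(5, 5)`, so the offsets to the ground-truth box `[2, 3, 9, 8]` are `l, t, r, b = 3, 2, 4, 3`; with `normalize_by_size=True` each offset is divided by the anchor's width and height (both 2 here):

```python
import torch
from torchvision.models.detection._utils import BoxLinearCoder

anchor = torch.tensor([[4.0, 4.0, 6.0, 6.0]])  # center at (5, 5)
gt = torch.tensor([[2.0, 3.0, 9.0, 8.0]])

# Raw left/top/right/bottom distances from the anchor center to the gt sides.
print(BoxLinearCoder(normalize_by_size=False).encode_all(anchor, gt))
# tensor([[3., 2., 4., 3.]])

# Same offsets, normalized by the anchor's width and height.
print(BoxLinearCoder(normalize_by_size=True).encode_all(anchor, gt))
# tensor([[1.5000, 1.0000, 2.0000, 1.5000]])
```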
@@ -88,19 +88,20 @@ class FCOSHead(nn.Module):
         pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)

+        # List[Tensor] to Tensor conversion of `all_gt_boxes_targets` and `anchors`
+        all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors)
+
         # amp issue: pred_boxes needs to be converted to float
         loss_bbox_reg = generalized_box_iou_loss(
             pred_boxes[foregroud_mask],
-            torch.stack(all_gt_boxes_targets)[foregroud_mask],
+            all_gt_boxes_targets[foregroud_mask],
             reduction="sum",
         )

         # ctrness loss
-        bbox_reg_targets = [
-            self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
-            for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
-        ]
-        bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
+        bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets)
+
         if len(bbox_reg_targets) == 0:
             gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
         else:
......
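For context on the elided `else` branch: the batched `bbox_reg_targets` feed the FCOS centerness target, sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))). Below is a hedged sketch of that computation on the `(N, A, 4)` targets, following the formula from the FCOS paper rather than the verbatim code elided above:

```python
import torch

def centerness_targets(bbox_reg_targets: torch.Tensor) -> torch.Tensor:
    # bbox_reg_targets: (N, A, 4) holding (l, t, r, b) offsets per anchor.
    left_right = bbox_reg_targets[..., [0, 2]]
    top_bottom = bbox_reg_targets[..., [1, 3]]
    # Centerness is 1 at the box center and decays toward the box edges.
    return torch.sqrt(
        (left_right.min(dim=-1).values / left_right.max(dim=-1).values)
        * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values)
    )
```

Because `encode_all` already produces the stacked `(N, A, 4)` tensor, this target can be computed in one shot with no per-image loop.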