Unverified Commit 96fa8204, authored by Abhijit Deo, committed by GitHub

cleanup for box encoding and decoding in FCOS (#6277)

* cleaning up box decoding

* minor nits

* cleanup for box encoding also added.
parent b30fa5c1
@@ -30,8 +30,8 @@ class TestModelsDetectionUtils:
         proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()
-        rel_codes = box_coder.encode_single(boxes, proposals)
-        pred_boxes = box_coder.decode_single(rel_codes, boxes)
+        rel_codes = box_coder.encode(boxes, proposals)
+        pred_boxes = box_coder.decode(rel_codes, boxes)
         torch.allclose(proposals, pred_boxes)

     @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
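The test exercises the renamed API directly: encoding proposals against a set of reference boxes and then decoding the result is an identity on the proposals. A self-contained version of that round trip, assuming a torchvision build that already includes this change (on earlier releases the same methods are named `encode_single`/`decode_single`):

```python
import torch
from torchvision.models.detection._utils import BoxLinearCoder

box_coder = BoxLinearCoder(normalize_by_size=True)

# Reference boxes and proposals in (x1, y1, x2, y2) format.
boxes = torch.tensor([[10.0, 10.0, 60.0, 60.0], [20.0, 30.0, 80.0, 90.0]])
proposals = torch.tensor([0.0, 0.0, 101.0, 101.0]).repeat(2, 1)

rel_codes = box_coder.encode(boxes, proposals)
pred_boxes = box_coder.decode(rel_codes, boxes)

# Decoding the encoded offsets must reproduce the proposals (up to float error).
assert torch.allclose(proposals, pred_boxes)
```

Note that the test calls `torch.allclose` without asserting on its return value; the sketch adds the `assert` so a failure would actually surface.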
@@ -237,42 +237,10 @@ class BoxLinearCoder:
         """
         self.normalize_by_size = normalize_by_size

-    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         """
         Encode a set of proposals with respect to some reference boxes
-
-        Args:
-            reference_boxes (Tensor): reference boxes
-            proposals (Tensor): boxes to be encoded
-
-        Returns:
-            Tensor: the encoded relative box offsets that can be used to
-            decode the boxes.
-        """
-
-        # get the center of reference_boxes
-        reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
-        reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])
-
-        # get box regression transformation deltas
-        target_l = reference_boxes_ctr_x - proposals[:, 0]
-        target_t = reference_boxes_ctr_y - proposals[:, 1]
-        target_r = proposals[:, 2] - reference_boxes_ctr_x
-        target_b = proposals[:, 3] - reference_boxes_ctr_y
-
-        targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
-
-        if self.normalize_by_size:
-            reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
-            reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
-            reference_boxes_size = torch.stack(
-                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
-            )
-            targets = targets / reference_boxes_size
-        return targets
-
-    def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
-        """
-        vectorized version of `encode_single`
+
         Args:
             reference_boxes (Tensor): reference boxes
             proposals (Tensor): boxes to be encoded
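The deleted `encode_single` body above is the whole of the linear encoding; the surviving `encode` (the former `encode_all`) computes the same quantities with rank-agnostic indexing. As a plain-torch sketch, assuming boxes are `(x1, y1, x2, y2)` tensors (the helper name `linear_encode` is hypothetical, not torchvision API):

```python
import torch
from torch import Tensor


def linear_encode(reference_boxes: Tensor, proposals: Tensor, normalize_by_size: bool = True) -> Tensor:
    """Distances (l, t, r, b) from each reference box center to the proposal edges."""
    # Center of the reference boxes; `...` indexing keeps this rank-agnostic,
    # so plain (N, 4) and stacked (L, N, 4) inputs both work.
    ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
    ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

    # Box regression transformation deltas relative to the center.
    target_l = ctr_x - proposals[..., 0]
    target_t = ctr_y - proposals[..., 1]
    target_r = proposals[..., 2] - ctr_x
    target_b = proposals[..., 3] - ctr_y
    targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)

    if normalize_by_size:
        # Make the deltas scale-invariant by dividing by reference width/height.
        w = reference_boxes[..., 2] - reference_boxes[..., 0]
        h = reference_boxes[..., 3] - reference_boxes[..., 1]
        targets = targets / torch.stack((w, h, w, h), dim=-1)
    return targets
```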
@@ -304,7 +272,8 @@ class BoxLinearCoder:
             targets = targets / reference_boxes_size
         return targets

-    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         """
         From a set of original boxes and encoded relative box offsets,
         get the decoded boxes.
@@ -313,35 +282,6 @@ class BoxLinearCoder:
             rel_codes (Tensor): encoded boxes
             boxes (Tensor): reference boxes.
-
-        Returns:
-            Tensor: the predicted boxes with the encoded relative box offsets.
-        """
-
-        boxes = boxes.to(rel_codes.dtype)
-
-        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
-        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
-
-        if self.normalize_by_size:
-            boxes_w = boxes[:, 2] - boxes[:, 0]
-            boxes_h = boxes[:, 3] - boxes[:, 1]
-            boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
-            rel_codes = rel_codes * boxes_size
-
-        pred_boxes1 = ctr_x - rel_codes[:, 0]
-        pred_boxes2 = ctr_y - rel_codes[:, 1]
-        pred_boxes3 = ctr_x + rel_codes[:, 2]
-        pred_boxes4 = ctr_y + rel_codes[:, 3]
-        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
-        return pred_boxes
-
-    def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
-        """
-        Vectorized version of `decode_single` method.
-
-        Args:
-            rel_codes (Tensor): encoded boxes
-            boxes (List[Tensor]): List of reference boxes.
-
         Returns:
             Tensor: the predicted boxes with the encoded relative box offsets.
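The deleted `decode_single` is the exact inverse, and `decode_all` below it differed only in stacking a list of boxes first. A matching sketch (`linear_decode` is again a hypothetical name, paired with `linear_encode` above):

```python
def linear_decode(rel_codes: Tensor, boxes: Tensor, normalize_by_size: bool = True) -> Tensor:
    """Invert linear_encode: turn (l, t, r, b) offsets back into (x1, y1, x2, y2) boxes."""
    boxes = boxes.to(rel_codes.dtype)
    ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
    ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

    if normalize_by_size:
        # Undo the width/height normalization applied during encoding.
        w = boxes[..., 2] - boxes[..., 0]
        h = boxes[..., 3] - boxes[..., 1]
        rel_codes = rel_codes * torch.stack((w, h, w, h), dim=-1)

    # Corners are the reference center shifted by the decoded distances.
    return torch.stack(
        (
            ctr_x - rel_codes[..., 0],
            ctr_y - rel_codes[..., 1],
            ctr_x + rel_codes[..., 2],
            ctr_y + rel_codes[..., 3],
        ),
        dim=-1,
    )
```

By construction, `linear_decode(linear_encode(ref, p), ref)` reproduces `p` up to floating-point error, which is exactly what the test at the top of this diff checks.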
@@ -350,7 +290,7 @@ class BoxLinearCoder:
         """

-        boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
+        boxes = boxes.to(dtype=rel_codes.dtype)

         ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
         ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
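This hunk is the one behavioral change in the surviving implementation: `decode` no longer stacks a `List[Tensor]` internally, and the `[..., i]` indexing lets a single code path serve both the single-level and the stacked multi-level callers, which is what made `decode_all` redundant. A quick shape check, reusing the `linear_decode` sketch above:

```python
# One feature level: (N, 4) in, (N, 4) out.
rel = torch.zeros(8, 4)
anchors_single = torch.rand(8, 4)
print(linear_decode(rel, anchors_single).shape)           # torch.Size([8, 4])

# Stacked levels/images: (L, N, 4) in, (L, N, 4) out, same function.
rel_batched = torch.zeros(3, 8, 4)
anchors_stacked = torch.rand(3, 8, 4)
print(linear_decode(rel_batched, anchors_stacked).shape)  # torch.Size([3, 8, 4])
```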
@@ -74,7 +74,13 @@ class FCOSHead(nn.Module):
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)

-        all_gt_classes_targets = torch.stack(all_gt_classes_targets)
+        # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
+        all_gt_boxes_targets, all_gt_classes_targets, anchors = (
+            torch.stack(all_gt_boxes_targets),
+            torch.stack(all_gt_classes_targets),
+            torch.stack(anchors),
+        )

         # compute foregroud
         foregroud_mask = all_gt_classes_targets >= 0
         num_foreground = foregroud_mask.sum().item()
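The replacement above is pure bookkeeping: all three per-image lists become batch tensors in one place, before any coder call needs them. For shape intuition (toy sizes, two images with five anchors each):

```python
import torch

per_image_boxes = [torch.rand(5, 4) for _ in range(2)]
batched = torch.stack(per_image_boxes)
print(batched.shape)  # torch.Size([2, 5, 4]): List[Tensor] -> one Tensor
```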
@@ -84,14 +90,10 @@ class FCOSHead(nn.Module):
         gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")

-        # regression loss: GIoU loss
-        pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
-
-        # List[Tensor] to Tensor conversion of `all_gt_boxes_target` and `anchors`
-        all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors)
-
         # amp issue: pred_boxes need to convert float
+        pred_boxes = self.box_coder.decode(bbox_regression, anchors)
+
+        # regression loss: GIoU loss
         loss_bbox_reg = generalized_box_iou_loss(
             pred_boxes[foregroud_mask],
             all_gt_boxes_targets[foregroud_mask],
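`generalized_box_iou_loss` here is the one exported by `torchvision.ops` in recent releases. A minimal sketch of the masked call pattern, with toy tensors standing in for the real head outputs:

```python
import torch
from torchvision.ops import generalized_box_iou_loss

# Toy stand-ins: (B, N, 4) predictions and targets, plus a (B, N) foreground mask.
pred_boxes = torch.tensor([[[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]]])
gt_boxes = torch.tensor([[[1.0, 1.0, 11.0, 11.0], [0.0, 0.0, 1.0, 1.0]]])
foreground = torch.tensor([[True, False]])  # only the first anchor matched a gt box

# Boolean masking flattens to (num_foreground, 4), so the regression loss
# only covers anchors that were assigned to a ground-truth box.
loss = generalized_box_iou_loss(pred_boxes[foreground], gt_boxes[foreground], reduction="sum")
print(loss)
```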
@@ -100,7 +102,7 @@ class FCOSHead(nn.Module):
         # ctrness loss
-        bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets)
+        bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)

         if len(bbox_reg_targets) == 0:
             gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
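Just below this hunk (outside the diff context), the head turns `bbox_reg_targets` into centerness targets. The formula is the one from the FCOS paper, sketched here over `(..., 4)` l/t/r/b targets:

```python
import torch


def ctrness_targets(bbox_reg_targets: torch.Tensor) -> torch.Tensor:
    """FCOS centerness: sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))."""
    left_right = bbox_reg_targets[..., [0, 2]]
    top_bottom = bbox_reg_targets[..., [1, 3]]
    return torch.sqrt(
        (left_right.min(dim=-1).values / left_right.max(dim=-1).values)
        * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values)
    )


# A location at the exact center of its box scores 1.0; off-center scores less.
print(ctrness_targets(torch.tensor([3.0, 3.0, 3.0, 3.0])))  # tensor(1.)
print(ctrness_targets(torch.tensor([1.0, 3.0, 3.0, 3.0])))  # ~0.577
```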
@@ -522,7 +524,7 @@ class FCOS(nn.Module):
                 anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                 labels_per_level = topk_idxs % num_classes

-                boxes_per_level = self.box_coder.decode_single(
+                boxes_per_level = self.box_coder.decode(
                     box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                 )
                 boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
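The index arithmetic in this last hunk is worth a note: the per-level scores are flattened over (anchor, class), so a single top-k index encodes both. Floor division recovers the anchor index and the modulo recovers the class label. A quick check with made-up numbers:

```python
import torch

num_classes = 3
# Flat index = anchor_idx * num_classes + class_idx, as produced by
# torch.topk over the flattened per-level scores.
topk_idxs = torch.tensor([0, 4, 11])

anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
labels = topk_idxs % num_classes
print(anchor_idxs)  # tensor([0, 1, 3])
print(labels)       # tensor([0, 1, 2])
```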