Unverified Commit af861bf9 authored by gy77's avatar gy77 Committed by GitHub
Browse files

Add type hints in iou3d.py (#1989)

* add type hints in iou3d.py

* fix lint
parent bb4c65d8
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from typing import Optional
import torch import torch
from torch import Tensor
from ..utils import ext_loader from ..utils import ext_loader
...@@ -11,7 +13,7 @@ ext_module = ext_loader.load_ext('_ext', [ ...@@ -11,7 +13,7 @@ ext_module = ext_loader.load_ext('_ext', [
]) ])
def boxes_iou3d(boxes_a, boxes_b): def boxes_iou3d(boxes_a: Tensor, boxes_b: Tensor) -> Tensor:
"""Calculate boxes 3D IoU. """Calculate boxes 3D IoU.
Args: Args:
...@@ -30,7 +32,7 @@ def boxes_iou3d(boxes_a, boxes_b): ...@@ -30,7 +32,7 @@ def boxes_iou3d(boxes_a, boxes_b):
return ans_iou return ans_iou
def nms3d(boxes, scores, iou_threshold): def nms3d(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""3D NMS function GPU implementation (for BEV boxes). """3D NMS function GPU implementation (for BEV boxes).
Args: Args:
...@@ -54,7 +56,8 @@ def nms3d(boxes, scores, iou_threshold): ...@@ -54,7 +56,8 @@ def nms3d(boxes, scores, iou_threshold):
return keep return keep
def nms3d_normal(boxes, scores, iou_threshold): def nms3d_normal(boxes: Tensor, scores: Tensor,
iou_threshold: float) -> Tensor:
"""Normal 3D NMS function GPU implementation. The overlap of two boxes for """Normal 3D NMS function GPU implementation. The overlap of two boxes for
IoU calculation is defined as the exact overlapping area of the two boxes IoU calculation is defined as the exact overlapping area of the two boxes
WITH their yaw angle set to 0. WITH their yaw angle set to 0.
...@@ -79,7 +82,7 @@ def nms3d_normal(boxes, scores, iou_threshold): ...@@ -79,7 +82,7 @@ def nms3d_normal(boxes, scores, iou_threshold):
return order[keep[:num_out].cuda(boxes.device)].contiguous() return order[keep[:num_out].cuda(boxes.device)].contiguous()
def _xyxyr2xywhr(boxes): def _xyxyr2xywhr(boxes: Tensor) -> Tensor:
"""Convert [x1, y1, x2, y2, heading] box to [x, y, dx, dy, heading] box. """Convert [x1, y1, x2, y2, heading] box to [x, y, dx, dy, heading] box.
Args: Args:
...@@ -119,7 +122,11 @@ def boxes_iou_bev(boxes_a, boxes_b): ...@@ -119,7 +122,11 @@ def boxes_iou_bev(boxes_a, boxes_b):
return box_iou_rotated(_xyxyr2xywhr(boxes_a), _xyxyr2xywhr(boxes_b)) return box_iou_rotated(_xyxyr2xywhr(boxes_a), _xyxyr2xywhr(boxes_b))
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): def nms_bev(boxes: Tensor,
scores: Tensor,
thresh: float,
pre_max_size: Optional[int] = None,
post_max_size: Optional[int] = None) -> Tensor:
"""NMS function GPU implementation (for BEV boxes). """NMS function GPU implementation (for BEV boxes).
The overlap of two The overlap of two
...@@ -159,7 +166,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): ...@@ -159,7 +166,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
return keep return keep
def nms_normal_bev(boxes, scores, thresh): def nms_normal_bev(boxes: Tensor, scores: Tensor, thresh: float) -> Tensor:
"""Normal NMS function GPU implementation (for BEV boxes). """Normal NMS function GPU implementation (for BEV boxes).
The overlap of The overlap of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment