"vscode:/vscode.git/clone" did not exist on "7fcc8911cbced68a045063a4edc0507688fc78e6"
Unverified Commit f5cdc7b9 authored by Gopi Krishna Erabati's avatar Gopi Krishna Erabati Committed by GitHub
Browse files

[Fix] Move bev_range to ObjectRangeFilter's call for consistency with gt_bboxes_3d type (#717)

* fix bev_range problem with points in different coordinate systems

* changed checking instance on points to gt_bboxes_3d to make it compatible with test

* sort imports
parent 786065a7
...@@ -4,7 +4,8 @@ from mmcv import is_tuple_of ...@@ -4,7 +4,8 @@ from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from mmdet3d.core import VoxelGenerator from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops)
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip from mmdet.datasets.pipelines import RandomFlip
from ..builder import OBJECTSAMPLERS from ..builder import OBJECTSAMPLERS
...@@ -699,7 +700,6 @@ class ObjectRangeFilter(object): ...@@ -699,7 +700,6 @@ class ObjectRangeFilter(object):
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
self.bev_range = self.pcd_range[[0, 1, 3, 4]]
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to filter objects by the range. """Call function to filter objects by the range.
...@@ -711,9 +711,16 @@ class ObjectRangeFilter(object): ...@@ -711,9 +711,16 @@ class ObjectRangeFilter(object):
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict. keys are updated in the result dict.
""" """
# Check points instance type and initialise bev_range
if isinstance(input_dict['gt_bboxes_3d'],
(LiDARInstance3DBoxes, DepthInstance3DBoxes)):
bev_range = self.pcd_range[[0, 1, 3, 4]]
elif isinstance(input_dict['gt_bboxes_3d'], CameraInstance3DBoxes):
bev_range = self.pcd_range[[0, 2, 3, 5]]
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
mask = gt_bboxes_3d.in_range_bev(self.bev_range) mask = gt_bboxes_3d.in_range_bev(bev_range)
gt_bboxes_3d = gt_bboxes_3d[mask] gt_bboxes_3d = gt_bboxes_3d[mask]
# mask is a torch tensor but gt_labels_3d is still numpy array # mask is a torch tensor but gt_labels_3d is still numpy array
# using mask to index gt_labels_3d will cause bug when # using mask to index gt_labels_3d will cause bug when
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment