Commit 82cd4892 authored by zhangwenwei's avatar zhangwenwei
Browse files

Add iou calculation in 3d structure

parent 0df95010
......@@ -50,6 +50,15 @@ class BaseInstance3DBoxes(object):
"""
return self.tensor[:, 3:6]
@property
def height(self):
"""Obtain the height of all the boxes.
Returns:
torch.Tensor: a vector with volume of each box.
"""
return self.tensor[:, 5]
@property
def center(self):
"""Calculate the center of all the boxes.
......@@ -275,3 +284,19 @@ class BaseInstance3DBoxes(object):
Yield a box as a Tensor of shape (4,) at a time.
"""
yield from self.tensor
@classmethod
def overlaps(cls, boxes1, boxes2, mode='iou', aligned=False):
"""Calculate overlaps of two boxes
Args:
boxes1 (:obj:BaseInstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:BaseInstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
aligned (bool, optional): Whether the boxes are aligned.
Defaults to False.
Returns:
torch.Tensor: Calculated iou of boxes
"""
pass
import numpy as np
import torch
from mmdet3d.ops.iou3d import iou3d_cuda
from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis
......@@ -29,6 +30,15 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
"""
@property
def height(self):
"""Obtain the height of all the boxes.
Returns:
torch.Tensor: a vector with volume of each box.
"""
return self.tensor[:, 4]
@property
def gravity_center(self):
"""Calculate the gravity center of all the boxes.
......@@ -84,6 +94,32 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
corners += self.tensor[:, :3].view(-1, 1, 3)
return corners
@property
def bev(self, mode='XYWHR'):
"""Calculate the 2D bounding boxes in BEV with rotation
Args:
mode (str): The mode of BEV boxes. Default to 'XYWHR'.
Returns:
torch.Tensor: a nx5 tensor of 2D BEV box of each box.
"""
boxes_xywhr = self.tensor[:, [0, 2, 3, 5, 6]]
if mode == 'XYWHR':
return boxes_xywhr
elif mode == 'XYXYR':
boxes = torch.zeros_like(boxes_xywhr)
boxes[:, 0] = boxes_xywhr[:, 0] - boxes_xywhr[2]
boxes[:, 1] = boxes_xywhr[:, 1] - boxes_xywhr[3]
boxes[:, 2] = boxes_xywhr[:, 0] + boxes_xywhr[2]
boxes[:, 3] = boxes_xywhr[:, 1] + boxes_xywhr[3]
boxes[:, 4] = boxes_xywhr[:, 4]
return boxes
else:
raise ValueError(
'Only support mode to be either "XYWHR" or "XYXYR",'
f'got {mode}')
@property
def nearset_bev(self):
"""Calculate the 2D bounding boxes in BEV without rotation
......@@ -92,7 +128,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
torch.Tensor: a tensor of 2D BEV box of each box.
"""
# Obtain BEV boxes with rotation in XZWHR format
bev_rotated_boxes = self.tensor[:, [0, 2, 3, 5, 6]]
bev_rotated_boxes = self.bev
# convert the rotation to a valid range
rotations = bev_rotated_boxes[:, -1]
normed_rotations = torch.abs(limit_period(rotations, 0.5, np.pi))
......@@ -158,3 +194,55 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
& (self.tensor[:, 0] < box_range[2])
& (self.tensor[:, 2] < box_range[3]))
return in_range_flags
@classmethod
def overlaps(cls, boxes1, boxes2, mode='iou'):
"""Calculate overlaps of two boxes
Args:
boxes1 (:obj:CameraInstance3DBoxes): boxes 1 contain N boxes
boxes2 (:obj:CameraInstance3DBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes
"""
assert isinstance(boxes1, CameraInstance3DBoxes)
assert isinstance(boxes2, CameraInstance3DBoxes)
assert mode in ['iou', 'iof']
# height overlap
boxes1_height_max = (boxes1.tensor[:, 1] + boxes1.height).view(-1, 1)
boxes1_height_min = boxes1.tensor[:, 1].view(-1, 1)
boxes2_height_max = (boxes2.tensor[:, 1] + boxes2.height).view(1, -1)
boxes2_height_min = boxes2.tensor[:, 1].view(1, -1)
max_of_min = torch.max(boxes1_height_min, boxes2_height_min)
min_of_max = torch.min(boxes1_height_max, boxes2_height_max)
overlaps_h = torch.clamp(min_of_max - max_of_min, min=0)
# obtain BEV boxes in XYXYR format
boxes1_bev = boxes1.bev(mode='XYXYR')
boxes2_bev = boxes2.bev(mode='XYXYR')
# bev overlap
overlaps_bev = boxes1_bev.new_zeros(
(boxes1_bev.shape[0], boxes2_bev.shape[0])).cuda() # (N, M)
iou3d_cuda.boxes_overlap_bev_gpu(boxes1_bev.contiguous().cuda(),
boxes2_bev.contiguous().cuda(),
overlaps_bev)
# 3d iou
overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h
volume1 = boxes1.volume.view(-1, 1)
volume2 = boxes2.volume.view(1, -1)
if mode == 'iou':
# the clamp func is used to avoid division of 0
iou3d = overlaps_3d / torch.clamp(
volume1 + volume2 - overlaps_3d, min=1e-8)
else:
iou3d = overlaps_3d / torch.clamp(volume1, min=1e-8)
return iou3d
import numpy as np
import torch
from mmdet3d.ops.iou3d import iou3d_cuda
from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis
......@@ -79,6 +80,32 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
corners += self.tensor[:, :3].view(-1, 1, 3)
return corners
@property
def bev(self, mode='XYWHR'):
"""Calculate the 2D bounding boxes in BEV with rotation
Args:
mode (str): The mode of BEV boxes. Default to 'XYWHR'.
Returns:
torch.Tensor: a nx5 tensor of 2D BEV box of each box.
"""
boxes_xywhr = self.tensor[:, [0, 1, 3, 4, 6]]
if mode == 'XYWHR':
return boxes_xywhr
elif mode == 'XYXYR':
boxes = torch.zeros_like(boxes_xywhr)
boxes[:, 0] = boxes_xywhr[:, 0] - boxes_xywhr[2]
boxes[:, 1] = boxes_xywhr[:, 1] - boxes_xywhr[3]
boxes[:, 2] = boxes_xywhr[:, 0] + boxes_xywhr[2]
boxes[:, 3] = boxes_xywhr[:, 1] + boxes_xywhr[3]
boxes[:, 4] = boxes_xywhr[:, 4]
return boxes
else:
raise ValueError(
'Only support mode to be either "XYWHR" or "XYXYR",'
f'got {mode}')
@property
def nearset_bev(self):
"""Calculate the 2D bounding boxes in BEV without rotation
......@@ -87,7 +114,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
torch.Tensor: a tensor of 2D BEV box of each box.
"""
# Obtain BEV boxes with rotation in XYWHR format
bev_rotated_boxes = self.tensor[:, [0, 1, 3, 4, 6]]
bev_rotated_boxes = self.bev
# convert the rotation to a valid range
rotations = bev_rotated_boxes[:, -1]
normed_rotations = torch.abs(limit_period(rotations, 0.5, np.pi))
......@@ -153,3 +180,55 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
& (self.tensor[:, 0] < box_range[2])
& (self.tensor[:, 1] < box_range[3]))
return in_range_flags
@classmethod
def overlaps(cls, boxes1, boxes2, mode='iou'):
"""Calculate overlaps of two boxes
Args:
boxes1 (:obj:LiDARInstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:LiDARInstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes
"""
assert isinstance(boxes1, LiDARInstance3DBoxes)
assert isinstance(boxes2, LiDARInstance3DBoxes)
assert mode in ['iou', 'iof']
# height overlap
boxes1_height_max = (boxes1.tensor[:, 2] + boxes1.height).view(-1, 1)
boxes1_height_min = boxes1.tensor[:, 2].view(-1, 1)
boxes2_height_max = (boxes2.tensor[:, 2] + boxes2.height).view(1, -1)
boxes2_height_min = boxes2.tensor[:, 2].view(1, -1)
max_of_min = torch.max(boxes1_height_min, boxes2_height_min)
min_of_max = torch.min(boxes1_height_max, boxes2_height_max)
overlaps_h = torch.clamp(min_of_max - max_of_min, min=0)
# obtain BEV boxes in XYXYR format
boxes1_bev = boxes1.bev(mode='XYXYR')
boxes2_bev = boxes2.bev(mode='XYXYR')
# bev overlap
overlaps_bev = boxes1_bev.new_zeros(
(boxes1_bev.shape[0], boxes2_bev.shape[0])).cuda() # (N, M)
iou3d_cuda.boxes_overlap_bev_gpu(boxes1_bev.contiguous().cuda(),
boxes2_bev.contiguous().cuda(),
overlaps_bev)
# 3d iou
overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h
volume1 = boxes1.volume.view(-1, 1)
volume2 = boxes2.volume.view(1, -1)
if mode == 'iou':
# the clamp func is used to avoid division of 0
iou3d = overlaps_3d / torch.clamp(
volume1 + volume2 - overlaps_3d, min=1e-8)
else:
iou3d = overlaps_3d / torch.clamp(volume1, min=1e-8)
return iou3d
......@@ -598,3 +598,60 @@ def test_camera_boxes3d():
# the pytorch print loses some precision
assert torch.allclose(boxes.corners, expected_tensor, rtol=1e-4, atol=1e-7)
def test_boxes3d_overlaps():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
# Test LiDAR boxes 3D overlaps
boxes1_tensor = torch.tensor(
[[1.8, -2.5 - 1.8, 1.75, 3.39, 1.65, 1.6615927],
[8.9, -2.5, -1.6, 1.54, 4.01, 1.57, 1.5215927],
[28.3, 0.5, -1.3, 1.47, 2.23, 1.48, 4.7115927],
[31.3, -8.2, -1.6, 1.74, 3.77, 1.48, 0.35159278]],
device='cuda')
boxes1 = LiDARInstance3DBoxes(boxes1_tensor)
boxes2_tensor = torch.tensor([[1.2, -3.0, -1.9, 1.8, 3.4, 1.7, 1.9],
[8.1, -2.9, -1.8, 1.5, 4.1, 1.6, 1.8],
[20.1, -28.5, -1.9, 1.6, 3.5, 1.4, 5.1],
[28.2, -16.5, -1.8, 1.7, 3.8, 1.5, 0.6]],
device='cuda')
boxes2 = LiDARInstance3DBoxes(boxes2_tensor)
from mmdet3d.ops.iou3d import boxes3d_to_bev_torch_lidar
expected_tensor = boxes3d_to_bev_torch_lidar(boxes1_tensor, boxes2_tensor)
overlaps_3d = boxes1.overlaps(boxes1, boxes2)
assert torch.allclose(expected_tensor, overlaps_3d)
# Test camera boxes 3D overlaps
boxes1_tensor = torch.tensor(
[[1.8, -2.5 - 1.8, 1.75, 3.39, 1.65, 1.6615927],
[8.9, -2.5, -1.6, 1.54, 4.01, 1.57, 1.5215927],
[28.3, 0.5, -1.3, 1.47, 2.23, 1.48, 4.7115927],
[31.3, -8.2, -1.6, 1.74, 3.77, 1.48, 0.35159278]],
device='cuda')
cam_boxes1_tensor = Box3DMode.convert(boxes1_tensor, Box3DMode.LIDAR,
Box3DMode.CAM)
cam_boxes1 = CameraInstance3DBoxes(cam_boxes1_tensor)
boxes2_tensor = torch.tensor([[1.2, -3.0, -1.9, 1.8, 3.4, 1.7, 1.9],
[8.1, -2.9, -1.8, 1.5, 4.1, 1.6, 1.8],
[20.1, -28.5, -1.9, 1.6, 3.5, 1.4, 5.1],
[28.2, -16.5, -1.8, 1.7, 3.8, 1.5, 0.6]],
device='cuda')
cam_boxes2_tensor = Box3DMode.convert(boxes2_tensor, Box3DMode.LIDAR,
Box3DMode.CAM)
cam_boxes2 = CameraInstance3DBoxes(cam_boxes2_tensor)
cam_overlaps_3d = cam_boxes1.overlaps(cam_boxes1, cam_boxes2)
from mmdet3d.ops.iou3d import boxes3d_to_bev_torch_camera
expected_tensor = boxes3d_to_bev_torch_camera(boxes1_tensor, boxes2_tensor)
assert torch.allclose(expected_tensor, cam_overlaps_3d)
assert torch.allclose(cam_overlaps_3d, overlaps_3d)
with pytest.raises(AssertionError):
cam_boxes1.overlaps(cam_boxes1, boxes1)
with pytest.raises(AssertionError):
boxes1.overlaps(cam_boxes1, boxes1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment