Commit 435fe45b authored by zhangwenwei's avatar zhangwenwei
Browse files

More unittest and fix bugs

parent f28fbecb
...@@ -89,7 +89,7 @@ class BaseInstance3DBoxes(object): ...@@ -89,7 +89,7 @@ class BaseInstance3DBoxes(object):
pass pass
@abstractmethod @abstractmethod
def in_range(self, box_range): def in_range_3d(self, box_range):
"""Check whether the boxes are in the given range """Check whether the boxes are in the given range
Args: Args:
...@@ -102,6 +102,20 @@ class BaseInstance3DBoxes(object): ...@@ -102,6 +102,20 @@ class BaseInstance3DBoxes(object):
""" """
pass pass
@abstractmethod
def in_range_bev(self, box_range):
    """Check whether the boxes are inside the given range in bird's-eye view.

    Args:
        box_range (list | torch.Tensor): Range of the boxes in BEV,
            in the order (x_min, y_min, x_max, y_max).

    Returns:
        A binary vector indicating whether each box is inside
        the reference range.
    """
    pass
@abstractmethod @abstractmethod
def scale(self, scale_factors): def scale(self, scale_factors):
"""Scale the box with horizontal and vertical scaling factors """Scale the box with horizontal and vertical scaling factors
......
...@@ -29,10 +29,10 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -29,10 +29,10 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Returns: Returns:
torch.Tensor: a tensor with center of each box. torch.Tensor: a tensor with center of each box.
""" """
bottom_center = self.bottom_center() bottom_center = self.bottom_center
gravity_center = torch.zeros_like(bottom_center) gravity_center = torch.zeros_like(bottom_center)
gravity_center[:, :2] = bottom_center[:, :2] gravity_center[:, :2] = bottom_center[:, :2]
gravity_center[:, 2] = bottom_center[:, 2] + bottom_center[:, 5] * 0.5 gravity_center[:, 2] = bottom_center[:, 2] + self.tensor[:, 5] * 0.5
return gravity_center return gravity_center
@property @property
...@@ -40,24 +40,44 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -40,24 +40,44 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the coordinates of corners of all the boxes. """Calculate the coordinates of corners of all the boxes.
Convert the boxes to the form of Convert the boxes to the form of
(x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1) (x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z0, x1y1z1)
.. code-block:: none
(x1, y0, z1) + ----------- + (x1, y1, z1)
/| / |
/ | / |
(x0, y0, z1) + ----------- + + (x1, y1, z0)
| / . | /
| / origin | /
(x0, y0, z0) + ----------- + (x1, y0, z0)
Returns: Returns:
torch.Tensor: corners of each box with size (N, 8, 3) torch.Tensor: corners of each box with size (N, 8, 3)
""" """
dims = self.tensor[:, 3:6] dims = self.dims
corners_norm = torch.from_numpy( corners_norm = torch.from_numpy(
np.stack(np.unravel_index(np.arange(2**3), [2] * 3), axis=1)).to( np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to(
device=dims.device, dtype=dims.dtype) device=dims.device, dtype=dims.dtype)
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
corners_norm = corners_norm - dims.new_tensor([0.5, 1.0, 0.5]) corners_norm = corners_norm - dims.new_tensor([0.5, 0.5, 0])
corners = dims.view([-1, 1, 3]) * corners_norm.reshape([1, 2**3, 3]) corners = dims.view([-1, 1, 3]) * corners_norm.reshape([1, 8, 3])
corners = rotation_3d_in_axis(corners, self.tensor[:, 6], axis=1) corners = rotation_3d_in_axis(corners, self.tensor[:, 6], axis=2)
corners += self.tensor[:, :3].view(-1, 1, 3) corners += self.tensor[:, :3].view(-1, 1, 3)
return corners return corners
@property
def dims(self):
    """Calculate the length in each dimension of all the boxes.

    Convert the boxes to the form of (x_size, y_size, z_size).

    Returns:
        torch.Tensor: size tensor of each box with shape (N, 3),
            i.e. columns 3:6 of the underlying box tensor.
    """
    return self.tensor[:, 3:6]
@property @property
def nearset_bev(self): def nearset_bev(self):
"""Calculate the 2D bounding boxes in BEV without rotation """Calculate the 2D bounding boxes in BEV without rotation
......
import numpy as np import numpy as np
import torch import torch
from mmdet3d.core.bbox import LiDARInstance3DBoxes from mmdet3d.core.bbox import Box3DMode, LiDARInstance3DBoxes
def test_lidar_boxes3d(): def test_lidar_boxes3d():
...@@ -128,6 +128,11 @@ def test_lidar_boxes3d(): ...@@ -128,6 +128,11 @@ def test_lidar_boxes3d():
mask = boxes.nonempty() mask = boxes.nonempty()
assert (mask == expected_tensor).all() assert (mask == expected_tensor).all()
# test bbox in_range
expected_tensor = torch.tensor([1, 1, 0, 0, 0], dtype=torch.bool)
mask = boxes.in_range_3d([0, -20, -2, 22, 2, 5])
assert (mask == expected_tensor).all()
# test bbox indexing # test bbox indexing
index_boxes = boxes[2:5] index_boxes = boxes[2:5]
expected_tensor = torch.tensor([[ expected_tensor = torch.tensor([[
...@@ -171,3 +176,77 @@ def test_lidar_boxes3d(): ...@@ -171,3 +176,77 @@ def test_lidar_boxes3d():
# test iteration # test iteration
for i, box in enumerate(index_boxes): for i, box in enumerate(index_boxes):
torch.allclose(box, expected_tensor[i]) torch.allclose(box, expected_tensor[i])
# test properties
assert torch.allclose(boxes.bottom_center, boxes.tensor[:, :3])
expected_tensor = (
boxes.tensor[:, :3] - boxes.tensor[:, 3:6] *
(torch.tensor([0.5, 0.5, 0]) - torch.tensor([0.5, 0.5, 0.5])))
assert torch.allclose(boxes.gravity_center, expected_tensor)
boxes.limit_yaw()
assert (boxes.tensor[:, 6] <= np.pi / 2).all()
assert (boxes.tensor[:, 6] >= -np.pi / 2).all()
Box3DMode.convert(boxes, Box3DMode.LIDAR, Box3DMode.LIDAR)
expected_tesor = boxes.tensor.clone()
assert torch.allclose(expected_tesor, boxes.tensor)
boxes.flip()
boxes.flip()
boxes.limit_yaw()
assert torch.allclose(expected_tesor, boxes.tensor)
# test nearest_bev
expected_tensor = torch.tensor([[-0.5763, -3.9307, 2.8326, -2.1709],
[6.0819, -5.7075, 10.1143, -4.1589],
[26.5212, -7.9800, 28.7637, -6.5018],
[18.2686, -29.2617, 21.7681, -27.6929],
[27.3398, -18.3976, 29.0896, -14.6065]])
# the pytorch print loses some precision
assert torch.allclose(
boxes.nearset_bev, expected_tensor, rtol=1e-4, atol=1e-7)
# obtained by the print of the original implementation
expected_tensor = torch.tensor([[[2.4093e+00, -4.4784e+00, -1.9169e+00],
[2.4093e+00, -4.4784e+00, -2.5769e-01],
[-7.7767e-01, -3.2684e+00, -2.5769e-01],
[-7.7767e-01, -3.2684e+00, -1.9169e+00],
[3.0340e+00, -2.8332e+00, -1.9169e+00],
[3.0340e+00, -2.8332e+00, -2.5769e-01],
[-1.5301e-01, -1.6232e+00, -2.5769e-01],
[-1.5301e-01, -1.6232e+00, -1.9169e+00]],
[[9.8933e+00, -6.1340e+00, -1.8019e+00],
[9.8933e+00, -6.1340e+00, -2.2310e-01],
[5.9606e+00, -5.2427e+00, -2.2310e-01],
[5.9606e+00, -5.2427e+00, -1.8019e+00],
[1.0236e+01, -4.6237e+00, -1.8019e+00],
[1.0236e+01, -4.6237e+00, -2.2310e-01],
[6.3029e+00, -3.7324e+00, -2.2310e-01],
[6.3029e+00, -3.7324e+00, -1.8019e+00]],
[[2.8525e+01, -8.2534e+00, -1.4676e+00],
[2.8525e+01, -8.2534e+00, 2.0648e-02],
[2.6364e+01, -7.6525e+00, 2.0648e-02],
[2.6364e+01, -7.6525e+00, -1.4676e+00],
[2.8921e+01, -6.8292e+00, -1.4676e+00],
[2.8921e+01, -6.8292e+00, 2.0648e-02],
[2.6760e+01, -6.2283e+00, 2.0648e-02],
[2.6760e+01, -6.2283e+00, -1.4676e+00]],
[[2.1337e+01, -2.9870e+01, -1.9028e+00],
[2.1337e+01, -2.9870e+01, -4.9495e-01],
[1.8102e+01, -2.8535e+01, -4.9495e-01],
[1.8102e+01, -2.8535e+01, -1.9028e+00],
[2.1935e+01, -2.8420e+01, -1.9028e+00],
[2.1935e+01, -2.8420e+01, -4.9495e-01],
[1.8700e+01, -2.7085e+01, -4.9495e-01],
[1.8700e+01, -2.7085e+01, -1.9028e+00]],
[[2.6398e+01, -1.7530e+01, -1.7879e+00],
[2.6398e+01, -1.7530e+01, -2.9959e-01],
[2.8612e+01, -1.4452e+01, -2.9959e-01],
[2.8612e+01, -1.4452e+01, -1.7879e+00],
[2.7818e+01, -1.8552e+01, -1.7879e+00],
[2.7818e+01, -1.8552e+01, -2.9959e-01],
[3.0032e+01, -1.5474e+01, -2.9959e-01],
[3.0032e+01, -1.5474e+01, -1.7879e+00]]])
# the pytorch print loses some precision
assert torch.allclose(boxes.corners, expected_tensor, rtol=1e-4, atol=1e-7)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment