"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "7c66f0c2e80fa3e401d2a82042d05d14c9169a5d"
Unverified Commit 6c03a971 authored by Tai-Wang's avatar Tai-Wang Committed by GitHub
Browse files

Release v1.1.0rc1

Release v1.1.0rc1
parents 9611c2d0 ca42c312
...@@ -280,7 +280,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -280,7 +280,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
overlaps_h = torch.clamp(heighest_of_bottom - lowest_of_top, min=0) overlaps_h = torch.clamp(heighest_of_bottom - lowest_of_top, min=0)
return overlaps_h return overlaps_h
def convert_to(self, dst, rt_mat=None): def convert_to(self, dst, rt_mat=None, correct_yaw=False):
"""Convert self to ``dst`` mode. """Convert self to ``dst`` mode.
Args: Args:
...@@ -291,14 +291,21 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -291,14 +291,21 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
The conversion from ``src`` coordinates to ``dst`` coordinates The conversion from ``src`` coordinates to ``dst`` coordinates
usually comes along the change of sensors, e.g., from camera usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix. to LiDAR. This requires a transformation matrix.
correct_yaw (bool): Whether to convert the yaw angle to the target
coordinate. Defaults to False.
Returns: Returns:
:obj:`BaseInstance3DBoxes`: :obj:`BaseInstance3DBoxes`:
The converted box of the same type in the ``dst`` mode. The converted box of the same type in the ``dst`` mode.
""" """
from .box_3d_mode import Box3DMode from .box_3d_mode import Box3DMode
# TODO: always set correct_yaw=True
return Box3DMode.convert( return Box3DMode.convert(
box=self, src=Box3DMode.CAM, dst=dst, rt_mat=rt_mat) box=self,
src=Box3DMode.CAM,
dst=dst,
rt_mat=rt_mat,
correct_yaw=correct_yaw)
def points_in_boxes_part(self, points, boxes_override=None): def points_in_boxes_part(self, points, boxes_override=None):
"""Find the box in which each point is. """Find the box in which each point is.
......
...@@ -41,7 +41,7 @@ class Coord3DMode(IntEnum): ...@@ -41,7 +41,7 @@ class Coord3DMode(IntEnum):
v v
down y down y
The relative coordinate of bottom center in a CAM box is [0.5, 1.0, 0.5], The relative coordinate of bottom center in a CAM box is (0.5, 1.0, 0.5),
and the yaw is around the y axis, thus the rotation axis=1. and the yaw is around the y axis, thus the rotation axis=1.
Coordinates in Depth mode: Coordinates in Depth mode:
......
...@@ -14,7 +14,7 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -14,7 +14,7 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
.. code-block:: none .. code-block:: none
up z y front (yaw=-0.5*pi) up z y front (yaw=0.5*pi)
^ ^ ^ ^
| / | /
| / | /
......
...@@ -174,7 +174,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -174,7 +174,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
points.flip(bev_direction) points.flip(bev_direction)
return points return points
def convert_to(self, dst, rt_mat=None): def convert_to(self, dst, rt_mat=None, correct_yaw=False):
"""Convert self to ``dst`` mode. """Convert self to ``dst`` mode.
Args: Args:
...@@ -185,14 +185,19 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -185,14 +185,19 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
The conversion from ``src`` coordinates to ``dst`` coordinates The conversion from ``src`` coordinates to ``dst`` coordinates
usually comes along the change of sensors, e.g., from camera usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix. to LiDAR. This requires a transformation matrix.
correct_yaw (bool): Whether to convert the yaw angle to the target
coordinate. Defaults to False.
Returns: Returns:
:obj:`BaseInstance3DBoxes`: :obj:`BaseInstance3DBoxes`:
The converted box of the same type in the ``dst`` mode. The converted box of the same type in the ``dst`` mode.
""" """
from .box_3d_mode import Box3DMode from .box_3d_mode import Box3DMode
return Box3DMode.convert( return Box3DMode.convert(
box=self, src=Box3DMode.LIDAR, dst=dst, rt_mat=rt_mat) box=self,
src=Box3DMode.LIDAR,
dst=dst,
rt_mat=rt_mat,
correct_yaw=correct_yaw)
def enlarged_box(self, extra_width): def enlarged_box(self, extra_width):
"""Enlarge the length, width and height boxes. """Enlarge the length, width and height boxes.
......
...@@ -333,3 +333,25 @@ def yaw2local(yaw, loc): ...@@ -333,3 +333,25 @@ def yaw2local(yaw, loc):
local_yaw[small_idx] += 2 * np.pi local_yaw[small_idx] += 2 * np.pi
return local_yaw return local_yaw
def get_lidar2img(cam2img, lidar2cam):
    """Get the projection matrix of lidar2img.

    A 3x3 input is embedded into the upper-left corner of a 4x4 matrix
    whose bottom-right element is 1, so the padded matrix is a valid
    homogeneous transform. Padding with an all-zero matrix would make the
    padded matrix singular, zero out the last row of the product and drop
    the fourth column of the other (4x4) operand.

    Args:
        cam2img (torch.Tensor): A 3x3 or 4x4 projection matrix.
        lidar2cam (torch.Tensor): A 3x3 or 4x4 projection matrix.

    Returns:
        torch.Tensor: Transformation matrix with shape 4x4.
    """
    if cam2img.shape == (3, 3):
        temp = cam2img.new_zeros(4, 4)
        temp[:3, :3] = cam2img
        # homogeneous coordinate must be carried through unchanged
        temp[3, 3] = 1
        cam2img = temp

    if lidar2cam.shape == (3, 3):
        temp = lidar2cam.new_zeros(4, 4)
        temp[:3, :3] = lidar2cam
        # homogeneous coordinate must be carried through unchanged
        temp[3, 3] = 1
        lidar2cam = temp

    return torch.matmul(cam2img, lidar2cam)
...@@ -58,6 +58,11 @@ def replace_ceph_backend(cfg): ...@@ -58,6 +58,11 @@ def replace_ceph_backend(cfg):
'LoadImageFromFileMono3D\'', 'LoadImageFromFileMono3D\'',
'LoadImageFromFileMono3D\',' + replace_strs) 'LoadImageFromFileMono3D\',' + replace_strs)
# replace LoadMultiViewImageFromFiles
cfg_pretty_text = cfg_pretty_text.replace(
'LoadMultiViewImageFromFiles\'',
'LoadMultiViewImageFromFiles\',' + replace_strs)
# replace LoadPointsFromFile # replace LoadPointsFromFile
cfg_pretty_text = cfg_pretty_text.replace( cfg_pretty_text = cfg_pretty_text.replace(
'LoadPointsFromFile\'', 'LoadPointsFromFile\',' + replace_strs) 'LoadPointsFromFile\'', 'LoadPointsFromFile\',' + replace_strs)
......
...@@ -70,6 +70,7 @@ def register_all_modules(init_default_scope: bool = True) -> None: ...@@ -70,6 +70,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
import mmdet3d.datasets # noqa: F401,F403 import mmdet3d.datasets # noqa: F401,F403
import mmdet3d.engine # noqa: F401,F403 import mmdet3d.engine # noqa: F401,F403
import mmdet3d.evaluation.metrics # noqa: F401,F403 import mmdet3d.evaluation.metrics # noqa: F401,F403
import mmdet3d.models # noqa: F401,F403
import mmdet3d.structures # noqa: F401,F403 import mmdet3d.structures # noqa: F401,F403
import mmdet3d.visualization # noqa: F401,F403 import mmdet3d.visualization # noqa: F401,F403
if init_default_scope: if init_default_scope:
......
...@@ -5,6 +5,7 @@ from typing import List, Optional, Union ...@@ -5,6 +5,7 @@ from typing import List, Optional, Union
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmdet3d.structures.det3d_data_sample import Det3DDataSample
from mmdet.models.task_modules.samplers import SamplingResult from mmdet.models.task_modules.samplers import SamplingResult
# Type hint of config data # Type hint of config data
...@@ -21,3 +22,4 @@ OptInstanceList = Optional[InstanceList] ...@@ -21,3 +22,4 @@ OptInstanceList = Optional[InstanceList]
SamplingResultList = List[SamplingResult] SamplingResultList = List[SamplingResult]
OptSamplingResultList = Optional[SamplingResultList] OptSamplingResultList = Optional[SamplingResultList]
SampleList = List[Det3DDataSample]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
__version__ = '1.1.0rc0' __version__ = '1.1.0rc1'
short_version = __version__ short_version = __version__
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
from os import path as osp
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import mmcv import mmcv
import numpy as np import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import PathPatch
from matplotlib.path import Path
from mmengine.dist import master_only from mmengine.dist import master_only
from torch import Tensor from torch import Tensor
from mmdet3d.structures.bbox_3d.box_3d_mode import Box3DMode
from mmdet.visualization import DetLocalVisualizer from mmdet.visualization import DetLocalVisualizer
try: try:
...@@ -24,10 +28,9 @@ from mmdet3d.registry import VISUALIZERS ...@@ -24,10 +28,9 @@ from mmdet3d.registry import VISUALIZERS
from mmdet3d.structures import (BaseInstance3DBoxes, CameraInstance3DBoxes, from mmdet3d.structures import (BaseInstance3DBoxes, CameraInstance3DBoxes,
Coord3DMode, DepthInstance3DBoxes, Coord3DMode, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, Det3DDataSample, LiDARInstance3DBoxes,
PointData) PointData, points_cam2img)
from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img, from .vis_utils import (proj_camera_bbox3d_to_img, proj_depth_bbox3d_to_img,
proj_lidar_bbox3d_to_img, to_depth_mode, write_obj, proj_lidar_bbox3d_to_img, to_depth_mode)
write_oriented_bbox)
@VISUALIZERS.register_module() @VISUALIZERS.register_module()
...@@ -42,8 +45,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -42,8 +45,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
Args: Args:
name (str): Name of the instance. Defaults to 'visualizer'. name (str): Name of the instance. Defaults to 'visualizer'.
points (numpy.array, shape=[N, 3+C]): points to visualize.
image (np.ndarray, optional): the origin image to draw. The format image (np.ndarray, optional): the origin image to draw. The format
should be RGB. Defaults to None. should be RGB. Defaults to None.
pcd_mode (int): The point cloud mode (coordinates):
0 represents LiDAR, 1 represents CAMERA, 2
represents Depth. Defaults to 0.
vis_backends (list, optional): Visual backend config list. vis_backends (list, optional): Visual backend config list.
Defaults to None. Defaults to None.
save_dir (str, optional): Save file dir for all storage backends. save_dir (str, optional): Save file dir for all storage backends.
...@@ -58,7 +65,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -58,7 +65,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
Defaults to None. Defaults to None.
line_width (int, float): The linewidth of lines. line_width (int, float): The linewidth of lines.
Defaults to 3. Defaults to 3.
vis_cfg (dict): The coordinate frame config while Open3D frame_cfg (dict): The coordinate frame config while Open3D
visualization initialization. visualization initialization.
Defaults to dict(size=1, origin=[0, 0, 0]). Defaults to dict(size=1, origin=[0, 0, 0]).
alpha (int, float): The transparency of bboxes or mask. alpha (int, float): The transparency of bboxes or mask.
...@@ -87,7 +94,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -87,7 +94,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
def __init__(self, def __init__(self,
name: str = 'visualizer', name: str = 'visualizer',
points: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None, image: Optional[np.ndarray] = None,
pcd_mode: int = 0,
vis_backends: Optional[Dict] = None, vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None, save_dir: Optional[str] = None,
bbox_color: Optional[Union[str, Tuple[int]]] = None, bbox_color: Optional[Union[str, Tuple[int]]] = None,
...@@ -95,7 +104,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -95,7 +104,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
Tuple[int]]] = (200, 200, 200), Tuple[int]]] = (200, 200, 200),
mask_color: Optional[Union[str, Tuple[int]]] = None, mask_color: Optional[Union[str, Tuple[int]]] = None,
line_width: Union[int, float] = 3, line_width: Union[int, float] = 3,
vis_cfg: dict = dict(size=1, origin=[0, 0, 0]), frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
alpha: float = 0.8): alpha: float = 0.8):
super().__init__( super().__init__(
name=name, name=name,
...@@ -107,32 +116,33 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -107,32 +116,33 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
mask_color=mask_color, mask_color=mask_color,
line_width=line_width, line_width=line_width,
alpha=alpha) alpha=alpha)
self.o3d_vis = self._initialize_o3d_vis(vis_cfg) if points is not None:
self.seg_num = 0 self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg)
self.pts_seg_num = 0
def _initialize_o3d_vis(self, vis_cfg) -> tuple: def _initialize_o3d_vis(self, frame_cfg) -> tuple:
"""Build open3d vis according to vis_cfg. """Initialize open3d vis according to frame_cfg.
Args: Args:
vis_cfg (dict): The config to build open3d vis. frame_cfg (dict): The config to create coordinate frame
in open3d vis.
Returns: Returns:
tuple: build open3d vis. :obj:`o3d.visualization.Visualizer`: Created open3d vis.
""" """
# init open3d visualizer
o3d_vis = o3d.visualization.Visualizer() o3d_vis = o3d.visualization.Visualizer()
o3d_vis.create_window() o3d_vis.create_window()
# create coordinate frame # create coordinate frame
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**vis_cfg) mesh_frame = geometry.TriangleMesh.create_coordinate_frame(**frame_cfg)
o3d_vis.add_geometry(mesh_frame) o3d_vis.add_geometry(mesh_frame)
return o3d_vis return o3d_vis
@master_only @master_only
def set_points(self, def set_points(self,
points: np.ndarray, points: np.ndarray,
pcd_mode: int = 0, pcd_mode: int = 0,
vis_task: str = 'det', frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
vis_task: str = 'lidar_det',
points_color: Tuple = (0.5, 0.5, 0.5), points_color: Tuple = (0.5, 0.5, 0.5),
points_size: int = 2, points_size: int = 2,
mode: str = 'xyz') -> None: mode: str = 'xyz') -> None:
...@@ -143,9 +153,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -143,9 +153,12 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
points to visualize. points to visualize.
pcd_mode (int): The point cloud mode (coordinates): pcd_mode (int): The point cloud mode (coordinates):
0 represents LiDAR, 1 represents CAMERA, 2 0 represents LiDAR, 1 represents CAMERA, 2
represents Depth. represents Depth. Defaults to 0.
frame_cfg (dict): The coordinate frame config while Open3D
visualization initialization.
Defaults to dict(size=1, origin=[0, 0, 0]).
vis_task (str): Visualiztion task, it includes: vis_task (str): Visualiztion task, it includes:
'det', 'multi_modality-det', 'mono-det', 'seg'. 'lidar_det', 'multi-modality_det', 'mono_det', 'lidar_seg'.
point_color (tuple[float], optional): the color of points. point_color (tuple[float], optional): the color of points.
Default: (0.5, 0.5, 0.5). Default: (0.5, 0.5, 0.5).
points_size (int, optional): the size of points to show points_size (int, optional): the size of points to show
...@@ -156,11 +169,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -156,11 +169,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
assert points is not None assert points is not None
check_type('points', points, np.ndarray) check_type('points', points, np.ndarray)
if not hasattr(self, 'o3d_vis'):
self.o3d_vis = self._initialize_o3d_vis(frame_cfg)
# for now we convert points into depth mode for visualization # for now we convert points into depth mode for visualization
if pcd_mode != Coord3DMode.DEPTH: if pcd_mode != Coord3DMode.DEPTH:
points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH) points = Coord3DMode.convert(points, pcd_mode, Coord3DMode.DEPTH)
if hasattr(self, 'pcd') and vis_task != 'seg': if hasattr(self, 'pcd') and vis_task != 'lidar_seg':
self.o3d_vis.remove_geometry(self.pcd) self.o3d_vis.remove_geometry(self.pcd)
# set points size in Open3D # set points size in Open3D
...@@ -190,7 +206,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -190,7 +206,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# We draw GT / pred bboxes on the same point cloud scenes # We draw GT / pred bboxes on the same point cloud scenes
# for better detection performance comparison # for better detection performance comparison
def draw_bboxes_3d(self, def draw_bboxes_3d(self,
bboxes_3d: DepthInstance3DBoxes, bboxes_3d: BaseInstance3DBoxes,
bbox_color=(0, 1, 0), bbox_color=(0, 1, 0),
points_in_box_color=(1, 0, 0), points_in_box_color=(1, 0, 0),
rot_axis=2, rot_axis=2,
...@@ -200,7 +216,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -200,7 +216,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
bbox3d. bbox3d.
Args: Args:
bboxes_3d (:obj:`DepthInstance3DBoxes`, shape=[M, 7]): bboxes_3d (:obj:`BaseInstance3DBoxes`, shape=[M, 7]):
3d bbox (x, y, z, x_size, y_size, z_size, yaw) to visualize. 3d bbox (x, y, z, x_size, y_size, z_size, yaw) to visualize.
bbox_color (tuple[float], optional): the color of 3D bboxes. bbox_color (tuple[float], optional): the color of 3D bboxes.
Default: (0, 1, 0). Default: (0, 1, 0).
...@@ -216,7 +232,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -216,7 +232,10 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
""" """
# Before visualizing the 3D Boxes in point cloud scene # Before visualizing the 3D Boxes in point cloud scene
# we need to convert the boxes to Depth mode # we need to convert the boxes to Depth mode
check_type('bboxes', bboxes_3d, (DepthInstance3DBoxes)) check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)
if not isinstance(bboxes_3d, DepthInstance3DBoxes):
bboxes_3d = bboxes_3d.convert_to(Box3DMode.DEPTH)
# convert bboxes to numpy dtype # convert bboxes to numpy dtype
bboxes_3d = tensor2ndarray(bboxes_3d.tensor) bboxes_3d = tensor2ndarray(bboxes_3d.tensor)
...@@ -255,31 +274,191 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -255,31 +274,191 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.pcd.colors = o3d.utility.Vector3dVector(self.points_colors) self.pcd.colors = o3d.utility.Vector3dVector(self.points_colors)
self.o3d_vis.update_geometry(self.pcd) self.o3d_vis.update_geometry(self.pcd)
def set_bev_image(self,
                  bev_image: Optional[np.ndarray] = None,
                  bev_shape: Optional[int] = 900) -> None:
    """Install the bird's-eye-view canvas and draw the camera view range.

    Args:
        bev_image (np.ndarray, optional): The bev image to draw on. When
            None, a black ``(bev_shape, bev_shape, 3)`` uint8 image is
            created. Defaults to None.
        bev_shape (int): Side length of the generated square image; only
            used when ``bev_image`` is None. Defaults to 900.
    """
    if bev_image is None:
        bev_image = np.zeros((bev_shape, bev_shape, 3), np.uint8)

    self._image = bev_image
    self.height, self.width = bev_image.shape[0], bev_image.shape[1]
    self._default_font_size = max(
        np.sqrt(self.height * self.width) // 90, 10)

    # reset the axes and show the canvas with its origin at the bottom
    axes = self.ax_save
    axes.cla()
    axes.axis(False)
    axes.imshow(bev_image, origin='lower')

    # plot camera view range: two dashed lines meeting at the bottom
    # centre of the canvas
    half_width = self.width / 2
    left_x = np.linspace(0, half_width)
    right_x = np.linspace(half_width, self.width)
    dashed = dict(ls='--', color='grey', linewidth=1, alpha=0.5)
    axes.plot(left_x, half_width - left_x, **dashed)
    axes.plot(right_x, right_x - half_width, **dashed)

    # mark the bottom centre of the canvas with a red cross
    axes.plot(
        half_width, 0, marker='+', markersize=16, markeredgecolor='red')
# TODO: Support bev point cloud visualization
@master_only
def draw_bev_bboxes(self,
                    bboxes_3d: BaseInstance3DBoxes,
                    scale: int = 15,
                    edge_colors: Union[str, tuple, List[str],
                                       List[tuple]] = 'orange',
                    line_styles: Union[str, List[str]] = '-',
                    line_widths: Union[Union[int, float],
                                       List[Union[int, float]]] = 1,
                    face_colors: Union[str, tuple, List[str],
                                       List[tuple]] = 'none',
                    alpha: Union[int, float] = 1):
    """Draw the bird's-eye-view footprints of 3D boxes on the bev image.

    Each box is turned into a rotated rectangle from its bev
    representation, scaled, shifted by half the canvas width along the
    x-axis and forwarded to ``self.draw_polygons``. ``self.width`` must
    already be set (``set_bev_image`` sets it).

    Args:
        bboxes_3d (:obj:`BaseInstance3DBoxes`, shape=[M, 7]):
            3d bbox (x, y, z, x_size, y_size, z_size, yaw) to visualize.
        scale (int): Value to scale the bev bboxes for better
            visualization. Defaults to 15.
        edge_colors (Union[str, tuple, List[str], List[tuple]]): The
            colors of bboxes. ``colors`` can have the same length with
            lines or just single value. If ``colors`` is single value, all
            the lines will have the same colors. Refer to `matplotlib.
            colors` for full list of formats that are accepted.
            Defaults to 'orange'.
        line_styles (Union[str, List[str]]): The linestyle
            of lines. ``line_styles`` can have the same length with
            texts or just single value. If ``line_styles`` is single
            value, all the lines will have the same linestyle.
            Reference to
            https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
            for more details. Defaults to '-'.
        line_widths (Union[Union[int, float], List[Union[int, float]]]):
            The linewidth of lines. ``line_widths`` can have
            the same length with lines or just single value.
            If ``line_widths`` is single value, all the lines will
            have the same linewidth. Defaults to 1.
        face_colors (Union[str, tuple, List[str], List[tuple]]):
            The face colors. Defaults to 'none'.
        alpha (Union[int, float]): The transparency of bboxes.
            Defaults to 1.

    Returns:
        The value returned by ``self.draw_polygons``.
    """
    check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)
    bev_bboxes = tensor2ndarray(bboxes_3d.bev)

    ctr, w, h, theta = np.split(bev_bboxes, [2, 3, 4], axis=-1)
    # scale the bev bboxes for better visualization; done out-of-place so
    # the array obtained from ``bboxes_3d.bev`` is never modified
    ctr, w, h = ctr * scale, w * scale, h * scale

    # half-extent vectors along the box's own (rotated) axes
    cos_value, sin_value = np.cos(theta), np.sin(theta)
    vec1 = np.concatenate([w / 2 * cos_value, w / 2 * sin_value], axis=-1)
    vec2 = np.concatenate([-h / 2 * sin_value, h / 2 * cos_value], axis=-1)

    # four corners, in drawing order around the rectangle
    pt1 = ctr + vec1 + vec2
    pt2 = ctr + vec1 - vec2
    pt3 = ctr - vec1 - vec2
    pt4 = ctr - vec1 + vec2
    poly = np.stack([pt1, pt2, pt3, pt4], axis=-2)

    # move the object along x-axis
    poly[:, :, 0] += self.width / 2

    # draw_polygons receives one (4, 2) array per box
    return self.draw_polygons(
        list(poly),
        alpha=alpha,
        edge_colors=edge_colors,
        line_styles=line_styles,
        line_widths=line_widths,
        face_colors=face_colors)
@master_only
def draw_points_on_image(
        self,
        points: Union[np.ndarray, Tensor],
        pts2img: np.ndarray,
        sizes: Optional[Union[np.ndarray, Tensor, int]] = 10) -> None:
    """Project points onto the current image and scatter-plot them.

    Points are coloured with the 'jet' colormap according to their
    projected depth modulo 20.

    Args:
        points (Union[np.ndarray, torch.Tensor]): Points to draw.
        pts2img (np.ndarray): The transformation matrix from the
            coordinate of point cloud to image plane.
        sizes (Optional[Union[np.ndarray, torch.Tensor, int]]): The
            marker size. Defaults to 10.
    """
    check_type('points', points, (np.ndarray, Tensor))
    pts_array = tensor2ndarray(points)
    assert self._image is not None, 'Please set image using `set_image`'

    # columns 0 and 1 are the image-plane coordinates, column 2 the depth
    uvd = points_cam2img(pts_array, pts2img, with_depth=True)
    u_coords, v_coords, depths = uvd[:, 0], uvd[:, 1], uvd[:, 2]

    # wrap the depth into [0, 1) so the colormap repeats every 20 units
    depth_phase = (depths % 20) / 20
    jet = plt.get_cmap('jet')
    self.ax_save.scatter(
        u_coords,
        v_coords,
        c=depth_phase,
        cmap=jet,
        s=sizes,
        alpha=0.5,
        edgecolors='none')
# TODO: set bbox color according to palette # TODO: set bbox color according to palette
@master_only
def draw_proj_bboxes_3d(self, def draw_proj_bboxes_3d(self,
bboxes_3d: BaseInstance3DBoxes, bboxes_3d: BaseInstance3DBoxes,
input_meta: dict, input_meta: dict,
bbox_color: Tuple[float] = 'b', edge_colors: Union[str, tuple, List[str],
List[tuple]] = 'royalblue',
line_styles: Union[str, List[str]] = '-', line_styles: Union[str, List[str]] = '-',
line_widths: Union[Union[int, float], line_widths: Union[Union[int, float],
List[Union[int, float]]] = 1): List[Union[int, float]]] = 2,
face_colors: Union[str, tuple, List[str],
List[tuple]] = 'royalblue',
alpha: Union[int, float] = 0.4):
"""Draw projected 3D boxes on the image. """Draw projected 3D boxes on the image.
Args: Args:
bbox3d (:obj:`BaseInstance3DBoxes`, shape=[M, 7]): bbox3d (:obj:`BaseInstance3DBoxes`, shape=[M, 7]):
3d bbox (x, y, z, x_size, y_size, z_size, yaw) to visualize. 3d bbox (x, y, z, x_size, y_size, z_size, yaw) to visualize.
input_meta (dict): Input meta information. input_meta (dict): Input meta information.
bbox_color (tuple[float], optional): the color of bbox. edge_colors (Union[str, tuple, List[str], List[tuple]]): The
Default: (0, 1, 0). colors of bboxes. ``colors`` can have the same length with
lines or just single value. If ``colors`` is single value, all
the lines will have the same colors. Refer to `matplotlib.
colors` for full list of formats that are accepted.
Defaults to 'royalblue'.
line_styles (Union[str, List[str]]): The linestyle line_styles (Union[str, List[str]]): The linestyle
of lines. ``line_styles`` can have the same length with of lines. ``line_styles`` can have the same length with
texts or just single value. If ``line_styles`` is single texts or just single value. If ``line_styles`` is single
value, all the lines will have the same linestyle. value, all the lines will have the same linestyle.
Reference to
https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
for more details. Defaults to '-'.
line_widths (Union[Union[int, float], List[Union[int, float]]]): line_widths (Union[Union[int, float], List[Union[int, float]]]):
The linewidth of lines. ``line_widths`` can have The linewidth of lines. ``line_widths`` can have
the same length with lines or just single value. the same length with lines or just single value.
If ``line_widths`` is single value, all the lines will If ``line_widths`` is single value, all the lines will
have the same linewidth. Defaults to 2. have the same linewidth. Defaults to 2.
face_colors (Union[str, tuple, List[str], List[tuple]]):
The face colors. Default to 'royalblue'.
alpha (Union[int, float]): The transparency of bboxes.
Defaults to 0.4.
""" """
check_type('bboxes', bboxes_3d, BaseInstance3DBoxes) check_type('bboxes', bboxes_3d, BaseInstance3DBoxes)
...@@ -293,27 +472,39 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -293,27 +472,39 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
raise NotImplementedError('unsupported box type!') raise NotImplementedError('unsupported box type!')
# (num_bboxes_3d, 8, 2) corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta)
proj_bboxes_3d = proj_bbox3d_to_img(bboxes_3d, input_meta)
num_bboxes_3d = proj_bboxes_3d.shape[0] lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6]
lines_verts = corners_2d[:, lines_verts_idx, :]
line_indices = ((0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (3, 2), (3, 7), front_polys = corners_2d[:, 4:, :]
(4, 5), (4, 7), (2, 6), (5, 6), (6, 7)) codes = [Path.LINETO] * lines_verts.shape[1]
codes[0] = Path.MOVETO
# TODO: assign each projected 3d bboxes color pathpatches = []
# according to pred / gt class. for i in range(len(corners_2d)):
for i in range(num_bboxes_3d): verts = lines_verts[i]
x_datas = [] pth = Path(verts, codes)
y_datas = [] pathpatches.append(PathPatch(pth))
corners = proj_bboxes_3d[i].astype(np.int) # (8, 2)
for start, end in line_indices: p = PatchCollection(
x_datas.append([corners[start][0], corners[end][0]]) pathpatches,
y_datas.append([corners[start][1], corners[end][1]]) facecolors='none',
x_datas = np.array(x_datas) edgecolors=edge_colors,
y_datas = np.array(y_datas) linewidths=line_widths,
self.draw_lines(x_datas, y_datas, bbox_color, line_styles, linestyles=line_styles)
line_widths)
self.ax_save.add_collection(p)
# draw a mask on the front of project bboxes
front_polys = [front_poly for front_poly in front_polys]
return self.draw_polygons(
front_polys,
alpha=alpha,
edge_colors=edge_colors,
line_styles=line_styles,
line_widths=line_widths,
face_colors=face_colors)
@master_only
def draw_seg_mask(self, seg_mask_colors: np.array): def draw_seg_mask(self, seg_mask_colors: np.array):
"""Add segmentation mask to visualizer via per-point colorization. """Add segmentation mask to visualizer via per-point colorization.
...@@ -325,15 +516,16 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -325,15 +516,16 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
# we can't draw the colors on existing points # we can't draw the colors on existing points
# in case gt and pred mask would overlap # in case gt and pred mask would overlap
# instead we set a large offset along x-axis for each seg mask # instead we set a large offset along x-axis for each seg mask
self.seg_num += 1 self.pts_seg_num += 1
offset = (np.array(self.pcd.points).max(0) - offset = (np.array(self.pcd.points).max(0) -
np.array(self.pcd.points).min(0))[0] * 1.2 * self.seg_num np.array(self.pcd.points).min(0))[0] * 1.2 * self.pts_seg_num
mesh_frame = geometry.TriangleMesh.create_coordinate_frame( mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
size=1, origin=[offset, 0, 0]) # create coordinate frame for seg size=1, origin=[offset, 0, 0]) # create coordinate frame for seg
self.o3d_vis.add_geometry(mesh_frame) self.o3d_vis.add_geometry(mesh_frame)
seg_points = copy.deepcopy(seg_mask_colors) seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset seg_points[:, 0] += offset
self.set_points(seg_points, vis_task='seg', pcd_mode=2, mode='xyzrgb') self.set_points(
seg_points, vis_task='lidar_seg', pcd_mode=2, mode='xyzrgb')
def _draw_instances_3d(self, data_input: dict, instances: InstanceData, def _draw_instances_3d(self, data_input: dict, instances: InstanceData,
input_meta: dict, vis_task: str, input_meta: dict, vis_task: str,
...@@ -346,7 +538,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -346,7 +538,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
instance-level annotations or predictions. instance-level annotations or predictions.
metainfo (dict): Meta information. metainfo (dict): Meta information.
vis_task (str): Visualiztion task, it includes: vis_task (str): Visualiztion task, it includes:
'det', 'multi_modality-det', 'mono-det'. 'lidar_det', 'multi-modality_det', 'mono_det'.
Returns: Returns:
dict: the drawn point cloud and image which channel is RGB. dict: the drawn point cloud and image which channel is RGB.
...@@ -356,7 +548,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -356,7 +548,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
data_3d = dict() data_3d = dict()
if vis_task in ['det', 'multi_modality-det']: if vis_task in ['lidar_det', 'multi-modality_det']:
assert 'points' in data_input assert 'points' in data_input
points = data_input['points'] points = data_input['points']
check_type('points', points, (np.ndarray, Tensor)) check_type('points', points, (np.ndarray, Tensor))
...@@ -373,7 +565,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -373,7 +565,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor) data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
data_3d['points'] = points data_3d['points'] = points
if vis_task in ['mono-det', 'multi_modality-det']: if vis_task in ['mono_det', 'multi-modality_det']:
assert 'img' in data_input assert 'img' in data_input
img = data_input['img'] img = data_input['img']
if isinstance(data_input['img'], Tensor): if isinstance(data_input['img'], Tensor):
...@@ -381,6 +573,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -381,6 +573,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
self.set_image(img) self.set_image(img)
self.draw_proj_bboxes_3d(bboxes_3d, input_meta) self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'):
centers_2d = instances.centers_2d
self.draw_points(centers_2d)
drawn_img = self.get_image() drawn_img = self.get_image()
data_3d['img'] = drawn_img data_3d['img'] = drawn_img
...@@ -419,27 +614,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -419,27 +614,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pts_color = palette[pts_sem_seg] pts_color = palette[pts_sem_seg]
seg_color = np.concatenate([points[:, :3], pts_color], axis=1) seg_color = np.concatenate([points[:, :3], pts_color], axis=1)
self.set_points(points, pcd_mode=2, vis_task='seg') self.set_points(points, pcd_mode=2, vis_task='lidar_seg')
self.draw_seg_mask(seg_color) self.draw_seg_mask(seg_color)
seg_data_3d = dict(points=points, seg_color=seg_color)
return seg_data_3d
@master_only @master_only
def show(self, def show(self,
vis_task: str = None, save_path: Optional[str] = None,
out_file: str = None,
drawn_img_3d: Optional[np.ndarray] = None, drawn_img_3d: Optional[np.ndarray] = None,
drawn_img: Optional[np.ndarray] = None, drawn_img: Optional[np.ndarray] = None,
win_name: str = 'image', win_name: str = 'image',
wait_time: int = 0, wait_time: int = 0,
continue_key=' ') -> None: continue_key=' ') -> None:
"""Show the drawn image. """Show the drawn point cloud/image.
Args: Args:
vis_task (str): Visualiztion task, it includes: save_path (str, optional): Path to save open3d visualized results.
'det', 'multi_modality-det', 'mono-det', 'seg'. Default: None.
out_file (str): Output file path.
drawn_img (np.ndarray, optional): The image to show. If drawn_img drawn_img (np.ndarray, optional): The image to show. If drawn_img
is None, it will show the image got by Visualizer. Defaults is None, it will show the image got by Visualizer. Defaults
to None. to None.
...@@ -449,16 +639,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -449,16 +639,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
continue_key (str): The key for users to continue. Defaults to continue_key (str): The key for users to continue. Defaults to
the space key. the space key.
""" """
if vis_task in ['det', 'multi_modality-det', 'seg']: if hasattr(self, 'o3d_vis'):
self.o3d_vis.run() self.o3d_vis.run()
if out_file is not None: if save_path is not None:
self.o3d_vis.capture_screen_image(out_file + '.png') self.o3d_vis.capture_screen_image(save_path)
self.o3d_vis.destroy_window() self.o3d_vis.destroy_window()
if vis_task in ['mono-det', 'multi_modality-det']: if hasattr(self, '_image'):
super().show(drawn_img_3d, win_name, wait_time, continue_key) if drawn_img_3d is None:
super().show(drawn_img_3d, win_name, wait_time, continue_key)
if drawn_img is not None:
super().show(drawn_img, win_name, wait_time, continue_key) super().show(drawn_img, win_name, wait_time, continue_key)
# TODO: Support Visualize the 3D results from image and point cloud # TODO: Support Visualize the 3D results from image and point cloud
...@@ -473,7 +662,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -473,7 +662,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
show: bool = False, show: bool = False,
wait_time: float = 0, wait_time: float = 0,
out_file: Optional[str] = None, out_file: Optional[str] = None,
vis_task: str = 'mono-det', save_path: Optional[str] = None,
vis_task: str = 'mono_det',
pred_score_thr: float = 0.3, pred_score_thr: float = 0.3,
step: int = 0) -> None: step: int = 0) -> None:
"""Draw datasample and save to all backends. """Draw datasample and save to all backends.
...@@ -501,7 +691,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -501,7 +691,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
image. Default to False. image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None. out_file (str): Path to output file. Defaults to None.
vis-task (str): Visualization task. Defaults to 'mono-det'. save_path (str, optional): Path to save open3d visualized results.
Default: None.
vis-task (str): Visualization task. Defaults to 'mono_det'.
pred_score_thr (float): The threshold to visualize the bboxes pred_score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3. and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0. step (int): Global step value to record. Defaults to 0.
...@@ -513,8 +705,6 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -513,8 +705,6 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
gt_data_3d = None gt_data_3d = None
pred_data_3d = None pred_data_3d = None
gt_seg_data_3d = None
pred_seg_data_3d = None
gt_img_data = None gt_img_data = None
pred_img_data = None pred_img_data = None
...@@ -524,23 +714,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -524,23 +714,22 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
data_input, data_sample.gt_instances_3d, data_input, data_sample.gt_instances_3d,
data_sample.metainfo, vis_task, palette) data_sample.metainfo, vis_task, palette)
if 'gt_instances' in data_sample: if 'gt_instances' in data_sample:
assert 'img' in data_input if len(data_sample.gt_instances) > 0:
if isinstance(data_input['img'], Tensor): assert 'img' in data_input
img = data_input['img'].permute(1, 2, 0).numpy() if isinstance(data_input['img'], Tensor):
img = img[..., [2, 1, 0]] # bgr to rgb img = data_input['img'].permute(1, 2, 0).numpy()
gt_img_data = self._draw_instances(img, img = img[..., [2, 1, 0]] # bgr to rgb
data_sample.gt_instances, gt_img_data = self._draw_instances(
classes, palette) img, data_sample.gt_instances, classes, palette)
if 'gt_pts_seg' in data_sample: if 'gt_pts_seg' in data_sample and vis_task == 'seg':
assert classes is not None, 'class information is ' \ assert classes is not None, 'class information is ' \
'not provided when ' \ 'not provided when ' \
'visualizing panoptic ' \ 'visualizing panoptic ' \
'segmentation results.' 'segmentation results.'
assert 'points' in data_input assert 'points' in data_input
gt_seg_data_3d = \ self._draw_pts_sem_seg(data_input['points'],
self._draw_pts_sem_seg(data_input['points'], data_sample.pred_pts_seg, palette,
data_sample.pred_pts_seg, ignore_index)
palette, ignore_index)
if draw_pred and data_sample is not None: if draw_pred and data_sample is not None:
if 'pred_instances_3d' in data_sample: if 'pred_instances_3d' in data_sample:
...@@ -563,19 +752,18 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -563,19 +752,18 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
pred_img_data = self._draw_instances( pred_img_data = self._draw_instances(
img, pred_instances, classes, palette) img, pred_instances, classes, palette)
if 'pred_pts_seg' in data_sample: if 'pred_pts_seg' in data_sample and vis_task == 'lidar_seg':
assert classes is not None, 'class information is ' \ assert classes is not None, 'class information is ' \
'not provided when ' \ 'not provided when ' \
'visualizing panoptic ' \ 'visualizing panoptic ' \
'segmentation results.' 'segmentation results.'
assert 'points' in data_input assert 'points' in data_input
pred_seg_data_3d = \ self._draw_pts_sem_seg(data_input['points'],
self._draw_pts_sem_seg(data_input['points'], data_sample.pred_pts_seg, palette,
data_sample.pred_pts_seg, ignore_index)
palette, ignore_index)
# monocular 3d object detection image # monocular 3d object detection image
if vis_task in ['mono-det', 'multi_modality-det']: if vis_task in ['mono_det', 'multi-modality_det']:
if gt_data_3d is not None and pred_data_3d is not None: if gt_data_3d is not None and pred_data_3d is not None:
drawn_img_3d = np.concatenate( drawn_img_3d = np.concatenate(
(gt_data_3d['img'], pred_data_3d['img']), axis=1) (gt_data_3d['img'], pred_data_3d['img']), axis=1)
...@@ -599,7 +787,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -599,7 +787,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if show: if show:
self.show( self.show(
vis_task, vis_task,
out_file, save_path,
drawn_img_3d, drawn_img_3d,
drawn_img, drawn_img,
win_name=name, win_name=name,
...@@ -607,29 +795,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -607,29 +795,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
if out_file is not None: if out_file is not None:
if drawn_img_3d is not None: if drawn_img_3d is not None:
mmcv.imwrite(drawn_img_3d[..., ::-1], out_file + '.jpg') mmcv.imwrite(drawn_img_3d[..., ::-1], out_file)
if drawn_img is not None: if drawn_img is not None:
mmcv.imwrite(drawn_img[..., ::-1], out_file + '.jpg') mmcv.imwrite(drawn_img[..., ::-1], out_file)
if gt_data_3d is not None:
write_obj(gt_data_3d['points'],
osp.join(out_file, 'points.obj'))
write_oriented_bbox(gt_data_3d['bboxes_3d'],
osp.join(out_file, 'gt_bbox.obj'))
if pred_data_3d is not None:
if 'points' in pred_data_3d:
write_obj(pred_data_3d['points'],
osp.join(out_file, 'points.obj'))
write_oriented_bbox(pred_data_3d['bboxes_3d'],
osp.join(out_file, 'pred_bbox.obj'))
if gt_seg_data_3d is not None:
write_obj(gt_seg_data_3d['points'],
osp.join(out_file, 'points.obj'))
write_obj(gt_seg_data_3d['seg_color'],
osp.join(out_file, 'gt_seg.obj'))
if pred_seg_data_3d is not None:
write_obj(pred_seg_data_3d['points'],
osp.join(out_file, 'points.obj'))
write_obj(pred_seg_data_3d['seg_color'],
osp.join(out_file, 'pred_seg.obj'))
else: else:
self.add_image(name, drawn_img_3d, step) self.add_image(name, drawn_img_3d, step)
spconv black==20.8b1 # be compatible with typing-extensions 3.7.4
waymo-open-dataset-tf-2-1-0==1.2.0 typing-extensions==3.7.4 # required by tensorflow<=2.6
waymo-open-dataset-tf-2-6-0 # requires python>=3.7
lyft_dataset_sdk lyft_dataset_sdk
networkx>=2.2,<2.3 networkx>=2.5
numba==0.53.0 numba==0.53.0
numpy numpy
nuscenes-devkit nuscenes-devkit
......
...@@ -13,4 +13,4 @@ no_lines_before = STDLIB,LOCALFOLDER ...@@ -13,4 +13,4 @@ no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
[codespell] [codespell]
ignore-words-list = ans,refridgerator,crate,hist,formating,dout,wan,nd,fo,avod,AVOD ignore-words-list = ans,refridgerator,crate,hist,formating,dout,wan,nd,fo,avod,AVOD,warmup
...@@ -223,6 +223,7 @@ if __name__ == '__main__': ...@@ -223,6 +223,7 @@ if __name__ == '__main__':
'tests': parse_requirements('requirements/tests.txt'), 'tests': parse_requirements('requirements/tests.txt'),
'build': parse_requirements('requirements/build.txt'), 'build': parse_requirements('requirements/build.txt'),
'optional': parse_requirements('requirements/optional.txt'), 'optional': parse_requirements('requirements/optional.txt'),
'mim': parse_requirements('requirements/mminstall.txt'),
}, },
cmdclass={'build_ext': BuildExtension}, cmdclass={'build_ext': BuildExtension},
zip_safe=False) zip_safe=False)
...@@ -28,7 +28,7 @@ def _generate_kitti_dataset_config(): ...@@ -28,7 +28,7 @@ def _generate_kitti_dataset_config():
gt_instances_3d = InstanceData() gt_instances_3d = InstanceData()
gt_instances_3d.labels_3d = info['gt_labels_3d'] gt_instances_3d.labels_3d = info['gt_labels_3d']
data_sample.gt_instances_3d = gt_instances_3d data_sample.gt_instances_3d = gt_instances_3d
info['data_sample'] = data_sample info['data_samples'] = data_sample
return info return info
pipeline = [ pipeline = [
...@@ -82,9 +82,9 @@ def test_getitem(): ...@@ -82,9 +82,9 @@ def test_getitem():
assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(), assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
torch.tensor(7.2650)) torch.tensor(7.2650))
assert 'centers_2d' in ann_info assert 'centers_2d' in ann_info
assert ann_info['centers_2d'].dtype == np.float64 assert ann_info['centers_2d'].dtype == np.float32
assert 'depths' in ann_info assert 'depths' in ann_info
assert ann_info['depths'].dtype == np.float64 assert ann_info['depths'].dtype == np.float32
car_kitti_dataset = KittiDataset( car_kitti_dataset = KittiDataset(
data_root, data_root,
......
...@@ -21,12 +21,12 @@ def _generate_nus_dataset_config(): ...@@ -21,12 +21,12 @@ def _generate_nus_dataset_config():
class Identity(BaseTransform): class Identity(BaseTransform):
def transform(self, info): def transform(self, info):
packed_input = dict(data_sample=Det3DDataSample()) packed_input = dict(data_samples=Det3DDataSample())
if 'ann_info' in info: if 'ann_info' in info:
packed_input['data_sample'].gt_instances_3d = InstanceData(
)
packed_input[ packed_input[
'data_sample'].gt_instances_3d.labels_3d = info[ 'data_samples'].gt_instances_3d = InstanceData()
packed_input[
'data_samples'].gt_instances_3d.labels_3d = info[
'ann_info']['gt_labels_3d'] 'ann_info']['gt_labels_3d']
return packed_input return packed_input
......
...@@ -21,12 +21,12 @@ def _generate_nus_dataset_config(): ...@@ -21,12 +21,12 @@ def _generate_nus_dataset_config():
class Identity(BaseTransform): class Identity(BaseTransform):
def transform(self, info): def transform(self, info):
packed_input = dict(data_sample=Det3DDataSample()) packed_input = dict(data_samples=Det3DDataSample())
if 'ann_info' in info: if 'ann_info' in info:
packed_input['data_sample'].gt_instances_3d = InstanceData(
)
packed_input[ packed_input[
'data_sample'].gt_instances_3d.labels_3d = info[ 'data_samples'].gt_instances_3d = InstanceData()
packed_input[
'data_samples'].gt_instances_3d.labels_3d = info[
'ann_info']['gt_labels_3d'] 'ann_info']['gt_labels_3d']
return packed_input return packed_input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment