"tests/vscode:/vscode.git/clone" did not exist on "c04ef89597f00d3ca4b4e7414dfb7f9ceba2e4f3"
Unverified Commit b50e2035 authored by Jingwei Zhang's avatar Jingwei Zhang Committed by GitHub
Browse files

[Enhance] Support different colors for different classes in visualization (#2500)

* support different colors

* use tuple palette
parent d99dbce7
......@@ -14,7 +14,7 @@ from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmdet3d.registry import MODELS
from mmdet3d.registry import DATASETS, MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
......@@ -38,6 +38,7 @@ def convert_SyncBN(config):
def init_model(config: Union[str, Path, Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
palette: str = 'none',
cfg_options: Optional[dict] = None):
"""Initialize a model from config file, which could be a 3D detector or a
3D segmentor.
......@@ -87,6 +88,20 @@ def init_model(config: Union[str, Path, Config],
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
test_dataset_cfg = deepcopy(config.test_dataloader.dataset)
# lazy init. We only need the metainfo.
test_dataset_cfg['lazy_init'] = True
metainfo = DATASETS.build(test_dataset_cfg).metainfo
cfg_palette = metainfo.get('palette', None)
if cfg_palette is not None:
model.dataset_meta['palette'] = cfg_palette
else:
if 'palette' not in model.dataset_meta:
warnings.warn(
'palette does not exist, random is used by default. '
'You can also set the palette to customize.')
model.dataset_meta['palette'] = 'random'
model.cfg = config # save the config in the model for convenience
if device != 'cpu':
torch.cuda.set_device(device)
......
......@@ -139,20 +139,21 @@ class Det3DDataset(BaseDataset):
self.metainfo['box_type_3d'] = box_type_3d
self.metainfo['label_mapping'] = self.label_mapping
# used for showing variation of the number of instances before and
# after through the pipeline
self.show_ins_var = show_ins_var
# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')
if not kwargs.get('lazy_init', False):
# used for showing variation of the number of instances before and
# after through the pipeline
self.show_ins_var = show_ins_var
# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')
def _remove_dontcare(self, ann_info: dict) -> dict:
"""Remove annotations that do not need to be cared.
......
......@@ -54,7 +54,9 @@ class KittiDataset(Det3DDataset):
# TODO: use full classes of kitti
METAINFO = {
'classes': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc')
'Person_sitting', 'Tram', 'Misc'),
'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
(197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255)]
}
def __init__(self,
......
......@@ -44,7 +44,10 @@ class LyftDataset(Det3DDataset):
METAINFO = {
'classes':
('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal')
'motorcycle', 'bicycle', 'pedestrian', 'animal'),
'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
(197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255),
(153, 69, 1)]
}
def __init__(self,
......
......@@ -60,7 +60,19 @@ class NuScenesDataset(Det3DDataset):
('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
'version':
'v1.0-trainval'
'v1.0-trainval',
'palette': [
(255, 158, 0), # Orange
(255, 99, 71), # Tomato
(255, 140, 0), # Darkorange
(255, 127, 80), # Coral
(233, 150, 70), # Darksalmon
(220, 20, 60), # Crimson
(255, 61, 99), # Red
(0, 0, 230), # Blue
(47, 79, 79), # Darkslategrey
(112, 128, 144), # Slategrey
]
}
def __init__(self,
......
......@@ -52,7 +52,10 @@ class S3DISDataset(Det3DDataset):
'classes': ('table', 'chair', 'sofa', 'bookcase', 'board'),
# the valid ids of segmentation annotations
'seg_valid_class_ids': (7, 8, 9, 10, 11),
'seg_all_class_ids': tuple(range(1, 14)) # possibly with 'stair' class
'seg_all_class_ids':
tuple(range(1, 14)), # possibly with 'stair' class
'palette': [(170, 120, 200), (255, 0, 0), (200, 100, 100),
(10, 200, 100), (200, 200, 200)]
}
def __init__(self,
......
......@@ -57,7 +57,13 @@ class ScanNetDataset(Det3DDataset):
'seg_valid_class_ids':
(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
'seg_all_class_ids':
tuple(range(1, 41))
tuple(range(1, 41)),
'palette': [(31, 119, 180), (255, 187, 120), (188, 189, 34),
(140, 86, 75), (255, 152, 150), (214, 39, 40),
(197, 176, 213), (148, 103, 189), (196, 156, 148),
(23, 190, 207), (247, 182, 210), (219, 219, 141),
(255, 127, 14), (158, 218, 229), (44, 160, 44),
(112, 128, 144), (227, 119, 194), (82, 84, 163)]
}
def __init__(self,
......
......@@ -119,12 +119,13 @@ class Seg3DDataset(BaseDataset):
**kwargs)
self.metainfo['seg_label_mapping'] = self.seg_label_mapping
self.scene_idxs = self.get_scene_idxs(scene_idxs)
self.data_list = [self.data_list[i] for i in self.scene_idxs]
if not kwargs.get('lazy_init', False):
self.scene_idxs = self.get_scene_idxs(scene_idxs)
self.data_list = [self.data_list[i] for i in self.scene_idxs]
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
# set group flag for the sampler
if not self.test_mode:
self._set_group_flag()
def get_label_mapping(self,
new_classes: Optional[Sequence] = None) -> tuple:
......
......@@ -47,7 +47,11 @@ class SUNRGBDDataset(Det3DDataset):
"""
METAINFO = {
'classes': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub')
'dresser', 'night_stand', 'bookshelf', 'bathtub'),
'palette': [(255, 187, 120), (255, 152, 150), (140, 86, 75),
(188, 189, 34), (44, 160, 44), (247, 182, 210),
(196, 156, 148), (23, 190, 207), (148, 103, 189),
(227, 119, 194)]
}
def __init__(self,
......
......@@ -68,7 +68,14 @@ class WaymoDataset(KittiDataset):
load_interval (int): load frame interval. Defaults to 1.
max_sweeps (int): max sweep for each frame. Defaults to 0.
"""
METAINFO = {'classes': ('Car', 'Pedestrian', 'Cyclist')}
METAINFO = {
'classes': ('Car', 'Pedestrian', 'Cyclist'),
'palette': [
(0, 120, 255), # Waymo Blue
(0, 232, 157), # Waymo Green
(255, 205, 85) # Amber
]
}
def __init__(self,
data_root: str,
......
......@@ -10,11 +10,12 @@ import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import PathPatch
from matplotlib.path import Path
from mmdet.visualization import DetLocalVisualizer
from mmdet.visualization import DetLocalVisualizer, get_palette
from mmengine.dist import master_only
from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer as MMENGINE_Visualizer
from mmengine.visualization.utils import check_type, tensor2ndarray
from mmengine.visualization.utils import (check_type, color_val_matplotlib,
tensor2ndarray)
from torch import Tensor
from mmdet3d.registry import VISUALIZERS
......@@ -173,7 +174,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pcd_mode: int = 0,
vis_mode: str = 'replace',
frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
points_color: Tuple[float] = (1, 1, 1),
points_color: Tuple[float] = (0.8, 0.8, 0.8),
points_size: int = 2,
mode: str = 'xyz') -> None:
"""Set the point cloud to draw.
......@@ -295,7 +296,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
line_set = geometry.LineSet.create_from_oriented_bounding_box(
box3d)
line_set.paint_uniform_color(bbox_color)
line_set.paint_uniform_color(np.array(bbox_color[i]) / 255.)
# draw bboxes on visualizer
self.o3d_vis.add_geometry(line_set)
......@@ -509,6 +510,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else:
raise NotImplementedError('unsupported box type!')
edge_colors_norm = color_val_matplotlib(edge_colors)
corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta)
if img_size is not None:
# Filter out the bbox where half of stuff is outside the image.
......@@ -518,6 +521,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
(corners_2d[..., 1] >= 0) & (corners_2d[..., 1] <= img_size[1]) # noqa: E501
valid_bbox_idx = valid_point_idx.sum(axis=-1) >= 4
corners_2d = corners_2d[valid_bbox_idx]
filter_edge_colors = []
filter_edge_colors_norm = []
for i, color in enumerate(edge_colors):
if valid_bbox_idx[i]:
filter_edge_colors.append(color)
filter_edge_colors_norm.append(edge_colors_norm[i])
edge_colors = filter_edge_colors
edge_colors_norm = filter_edge_colors_norm
lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6]
lines_verts = corners_2d[:, lines_verts_idx, :]
......@@ -533,7 +544,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
p = PatchCollection(
pathpatches,
facecolors='none',
edgecolors=edge_colors,
edgecolors=edge_colors_norm,
linewidths=line_widths,
linestyles=line_styles)
......@@ -547,7 +558,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
edge_colors=edge_colors,
line_styles=line_styles,
line_widths=line_widths,
face_colors=face_colors)
face_colors=edge_colors)
@master_only
def draw_seg_mask(self, seg_mask_colors: np.ndarray) -> None:
......@@ -598,6 +609,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
return None
bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes
labels_3d = instances.labels_3d
data_3d = dict()
......@@ -612,8 +624,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else:
bboxes_3d_depth = bboxes_3d.clone()
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.set_points(points, pcd_mode=2)
self.draw_bboxes_3d(bboxes_3d_depth)
self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors)
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
data_3d['points'] = points
......@@ -646,10 +664,19 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
single_img_meta[key] = meta[i]
else:
single_img_meta[key] = meta
max_label = int(
max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.draw_proj_bboxes_3d(
bboxes_3d,
single_img_meta,
img_size=single_img.shape[:2][::-1])
img_size=single_img.shape[:2][::-1],
edge_colors=colors)
if vis_task == 'mono_det' and hasattr(
instances, 'centers_2d'):
centers_2d = instances.centers_2d
......@@ -668,7 +695,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
img = img.permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
self.set_image(img)
self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.draw_proj_bboxes_3d(
bboxes_3d, input_meta, edge_colors=colors)
if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'):
centers_2d = instances.centers_2d
self.draw_points(centers_2d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment