Unverified Commit b50e2035 authored by Jingwei Zhang's avatar Jingwei Zhang Committed by GitHub
Browse files

[Enhance] Support different colors for different classes in visualization (#2500)

* support different colors

* use tuple palette
parent d99dbce7
...@@ -14,7 +14,7 @@ from mmengine.dataset import Compose, pseudo_collate ...@@ -14,7 +14,7 @@ from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint from mmengine.runner import load_checkpoint
from mmdet3d.registry import MODELS from mmdet3d.registry import DATASETS, MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
...@@ -38,6 +38,7 @@ def convert_SyncBN(config): ...@@ -38,6 +38,7 @@ def convert_SyncBN(config):
def init_model(config: Union[str, Path, Config], def init_model(config: Union[str, Path, Config],
checkpoint: Optional[str] = None, checkpoint: Optional[str] = None,
device: str = 'cuda:0', device: str = 'cuda:0',
palette: str = 'none',
cfg_options: Optional[dict] = None): cfg_options: Optional[dict] = None):
"""Initialize a model from config file, which could be a 3D detector or a """Initialize a model from config file, which could be a 3D detector or a
3D segmentor. 3D segmentor.
...@@ -87,6 +88,20 @@ def init_model(config: Union[str, Path, Config], ...@@ -87,6 +88,20 @@ def init_model(config: Union[str, Path, Config],
if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE'] model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
test_dataset_cfg = deepcopy(config.test_dataloader.dataset)
# lazy init. We only need the metainfo.
test_dataset_cfg['lazy_init'] = True
metainfo = DATASETS.build(test_dataset_cfg).metainfo
cfg_palette = metainfo.get('palette', None)
if cfg_palette is not None:
model.dataset_meta['palette'] = cfg_palette
else:
if 'palette' not in model.dataset_meta:
warnings.warn(
'palette does not exist, random is used by default. '
'You can also set the palette to customize.')
model.dataset_meta['palette'] = 'random'
model.cfg = config # save the config in the model for convenience model.cfg = config # save the config in the model for convenience
if device != 'cpu': if device != 'cpu':
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
...@@ -139,20 +139,21 @@ class Det3DDataset(BaseDataset): ...@@ -139,20 +139,21 @@ class Det3DDataset(BaseDataset):
self.metainfo['box_type_3d'] = box_type_3d self.metainfo['box_type_3d'] = box_type_3d
self.metainfo['label_mapping'] = self.label_mapping self.metainfo['label_mapping'] = self.label_mapping
# used for showing variation of the number of instances before and if not kwargs.get('lazy_init', False):
# after through the pipeline # used for showing variation of the number of instances before and
self.show_ins_var = show_ins_var # after through the pipeline
self.show_ins_var = show_ins_var
# show statistics of this dataset
print_log('-' * 30, 'current') # show statistics of this dataset
print_log(f'The length of the dataset: {len(self)}', 'current') print_log('-' * 30, 'current')
content_show = [['category', 'number']] print_log(f'The length of the dataset: {len(self)}', 'current')
for cat_name, num in self.num_ins_per_cat.items(): content_show = [['category', 'number']]
content_show.append([cat_name, num]) for cat_name, num in self.num_ins_per_cat.items():
table = AsciiTable(content_show) content_show.append([cat_name, num])
print_log( table = AsciiTable(content_show)
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501 print_log(
'current') f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')
def _remove_dontcare(self, ann_info: dict) -> dict: def _remove_dontcare(self, ann_info: dict) -> dict:
"""Remove annotations that do not need to be cared. """Remove annotations that do not need to be cared.
......
...@@ -54,7 +54,9 @@ class KittiDataset(Det3DDataset): ...@@ -54,7 +54,9 @@ class KittiDataset(Det3DDataset):
# TODO: use full classes of kitti # TODO: use full classes of kitti
METAINFO = { METAINFO = {
'classes': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck', 'classes': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc') 'Person_sitting', 'Tram', 'Misc'),
'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
(197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255)]
} }
def __init__(self, def __init__(self,
......
...@@ -44,7 +44,10 @@ class LyftDataset(Det3DDataset): ...@@ -44,7 +44,10 @@ class LyftDataset(Det3DDataset):
METAINFO = { METAINFO = {
'classes': 'classes':
('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle', ('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal') 'motorcycle', 'bicycle', 'pedestrian', 'animal'),
'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
(197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255),
(153, 69, 1)]
} }
def __init__(self, def __init__(self,
......
...@@ -60,7 +60,19 @@ class NuScenesDataset(Det3DDataset): ...@@ -60,7 +60,19 @@ class NuScenesDataset(Det3DDataset):
('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle', ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'), 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
'version': 'version':
'v1.0-trainval' 'v1.0-trainval',
'palette': [
(255, 158, 0), # Orange
(255, 99, 71), # Tomato
(255, 140, 0), # Darkorange
(255, 127, 80), # Coral
(233, 150, 70), # Darksalmon
(220, 20, 60), # Crimson
(255, 61, 99), # Red
(0, 0, 230), # Blue
(47, 79, 79), # Darkslategrey
(112, 128, 144), # Slategrey
]
} }
def __init__(self, def __init__(self,
......
...@@ -52,7 +52,10 @@ class S3DISDataset(Det3DDataset): ...@@ -52,7 +52,10 @@ class S3DISDataset(Det3DDataset):
'classes': ('table', 'chair', 'sofa', 'bookcase', 'board'), 'classes': ('table', 'chair', 'sofa', 'bookcase', 'board'),
# the valid ids of segmentation annotations # the valid ids of segmentation annotations
'seg_valid_class_ids': (7, 8, 9, 10, 11), 'seg_valid_class_ids': (7, 8, 9, 10, 11),
'seg_all_class_ids': tuple(range(1, 14)) # possibly with 'stair' class 'seg_all_class_ids':
tuple(range(1, 14)), # possibly with 'stair' class
'palette': [(170, 120, 200), (255, 0, 0), (200, 100, 100),
(10, 200, 100), (200, 200, 200)]
} }
def __init__(self, def __init__(self,
......
...@@ -57,7 +57,13 @@ class ScanNetDataset(Det3DDataset): ...@@ -57,7 +57,13 @@ class ScanNetDataset(Det3DDataset):
'seg_valid_class_ids': 'seg_valid_class_ids':
(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39), (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
'seg_all_class_ids': 'seg_all_class_ids':
tuple(range(1, 41)) tuple(range(1, 41)),
'palette': [(31, 119, 180), (255, 187, 120), (188, 189, 34),
(140, 86, 75), (255, 152, 150), (214, 39, 40),
(197, 176, 213), (148, 103, 189), (196, 156, 148),
(23, 190, 207), (247, 182, 210), (219, 219, 141),
(255, 127, 14), (158, 218, 229), (44, 160, 44),
(112, 128, 144), (227, 119, 194), (82, 84, 163)]
} }
def __init__(self, def __init__(self,
......
...@@ -119,12 +119,13 @@ class Seg3DDataset(BaseDataset): ...@@ -119,12 +119,13 @@ class Seg3DDataset(BaseDataset):
**kwargs) **kwargs)
self.metainfo['seg_label_mapping'] = self.seg_label_mapping self.metainfo['seg_label_mapping'] = self.seg_label_mapping
self.scene_idxs = self.get_scene_idxs(scene_idxs) if not kwargs.get('lazy_init', False):
self.data_list = [self.data_list[i] for i in self.scene_idxs] self.scene_idxs = self.get_scene_idxs(scene_idxs)
self.data_list = [self.data_list[i] for i in self.scene_idxs]
# set group flag for the sampler # set group flag for the sampler
if not self.test_mode: if not self.test_mode:
self._set_group_flag() self._set_group_flag()
def get_label_mapping(self, def get_label_mapping(self,
new_classes: Optional[Sequence] = None) -> tuple: new_classes: Optional[Sequence] = None) -> tuple:
......
...@@ -47,7 +47,11 @@ class SUNRGBDDataset(Det3DDataset): ...@@ -47,7 +47,11 @@ class SUNRGBDDataset(Det3DDataset):
""" """
METAINFO = { METAINFO = {
'classes': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'classes': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub') 'dresser', 'night_stand', 'bookshelf', 'bathtub'),
'palette': [(255, 187, 120), (255, 152, 150), (140, 86, 75),
(188, 189, 34), (44, 160, 44), (247, 182, 210),
(196, 156, 148), (23, 190, 207), (148, 103, 189),
(227, 119, 194)]
} }
def __init__(self, def __init__(self,
......
...@@ -68,7 +68,14 @@ class WaymoDataset(KittiDataset): ...@@ -68,7 +68,14 @@ class WaymoDataset(KittiDataset):
load_interval (int): load frame interval. Defaults to 1. load_interval (int): load frame interval. Defaults to 1.
max_sweeps (int): max sweep for each frame. Defaults to 0. max_sweeps (int): max sweep for each frame. Defaults to 0.
""" """
METAINFO = {'classes': ('Car', 'Pedestrian', 'Cyclist')} METAINFO = {
'classes': ('Car', 'Pedestrian', 'Cyclist'),
'palette': [
(0, 120, 255), # Waymo Blue
(0, 232, 157), # Waymo Green
(255, 205, 85) # Amber
]
}
def __init__(self, def __init__(self,
data_root: str, data_root: str,
......
...@@ -10,11 +10,12 @@ import numpy as np ...@@ -10,11 +10,12 @@ import numpy as np
from matplotlib.collections import PatchCollection from matplotlib.collections import PatchCollection
from matplotlib.patches import PathPatch from matplotlib.patches import PathPatch
from matplotlib.path import Path from matplotlib.path import Path
from mmdet.visualization import DetLocalVisualizer from mmdet.visualization import DetLocalVisualizer, get_palette
from mmengine.dist import master_only from mmengine.dist import master_only
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmengine.visualization import Visualizer as MMENGINE_Visualizer from mmengine.visualization import Visualizer as MMENGINE_Visualizer
from mmengine.visualization.utils import check_type, tensor2ndarray from mmengine.visualization.utils import (check_type, color_val_matplotlib,
tensor2ndarray)
from torch import Tensor from torch import Tensor
from mmdet3d.registry import VISUALIZERS from mmdet3d.registry import VISUALIZERS
...@@ -173,7 +174,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -173,7 +174,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
pcd_mode: int = 0, pcd_mode: int = 0,
vis_mode: str = 'replace', vis_mode: str = 'replace',
frame_cfg: dict = dict(size=1, origin=[0, 0, 0]), frame_cfg: dict = dict(size=1, origin=[0, 0, 0]),
points_color: Tuple[float] = (1, 1, 1), points_color: Tuple[float] = (0.8, 0.8, 0.8),
points_size: int = 2, points_size: int = 2,
mode: str = 'xyz') -> None: mode: str = 'xyz') -> None:
"""Set the point cloud to draw. """Set the point cloud to draw.
...@@ -295,7 +296,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -295,7 +296,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
line_set = geometry.LineSet.create_from_oriented_bounding_box( line_set = geometry.LineSet.create_from_oriented_bounding_box(
box3d) box3d)
line_set.paint_uniform_color(bbox_color) line_set.paint_uniform_color(np.array(bbox_color[i]) / 255.)
# draw bboxes on visualizer # draw bboxes on visualizer
self.o3d_vis.add_geometry(line_set) self.o3d_vis.add_geometry(line_set)
...@@ -509,6 +510,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -509,6 +510,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
raise NotImplementedError('unsupported box type!') raise NotImplementedError('unsupported box type!')
edge_colors_norm = color_val_matplotlib(edge_colors)
corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta) corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta)
if img_size is not None: if img_size is not None:
# Filter out the bbox where half of stuff is outside the image. # Filter out the bbox where half of stuff is outside the image.
...@@ -518,6 +521,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -518,6 +521,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
(corners_2d[..., 1] >= 0) & (corners_2d[..., 1] <= img_size[1]) # noqa: E501 (corners_2d[..., 1] >= 0) & (corners_2d[..., 1] <= img_size[1]) # noqa: E501
valid_bbox_idx = valid_point_idx.sum(axis=-1) >= 4 valid_bbox_idx = valid_point_idx.sum(axis=-1) >= 4
corners_2d = corners_2d[valid_bbox_idx] corners_2d = corners_2d[valid_bbox_idx]
filter_edge_colors = []
filter_edge_colors_norm = []
for i, color in enumerate(edge_colors):
if valid_bbox_idx[i]:
filter_edge_colors.append(color)
filter_edge_colors_norm.append(edge_colors_norm[i])
edge_colors = filter_edge_colors
edge_colors_norm = filter_edge_colors_norm
lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6] lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6]
lines_verts = corners_2d[:, lines_verts_idx, :] lines_verts = corners_2d[:, lines_verts_idx, :]
...@@ -533,7 +544,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -533,7 +544,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
p = PatchCollection( p = PatchCollection(
pathpatches, pathpatches,
facecolors='none', facecolors='none',
edgecolors=edge_colors, edgecolors=edge_colors_norm,
linewidths=line_widths, linewidths=line_widths,
linestyles=line_styles) linestyles=line_styles)
...@@ -547,7 +558,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -547,7 +558,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
edge_colors=edge_colors, edge_colors=edge_colors,
line_styles=line_styles, line_styles=line_styles,
line_widths=line_widths, line_widths=line_widths,
face_colors=face_colors) face_colors=edge_colors)
@master_only @master_only
def draw_seg_mask(self, seg_mask_colors: np.ndarray) -> None: def draw_seg_mask(self, seg_mask_colors: np.ndarray) -> None:
...@@ -598,6 +609,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -598,6 +609,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
return None return None
bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes
labels_3d = instances.labels_3d
data_3d = dict() data_3d = dict()
...@@ -612,8 +624,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -612,8 +624,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
else: else:
bboxes_3d_depth = bboxes_3d.clone() bboxes_3d_depth = bboxes_3d.clone()
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.set_points(points, pcd_mode=2) self.set_points(points, pcd_mode=2)
self.draw_bboxes_3d(bboxes_3d_depth) self.draw_bboxes_3d(bboxes_3d_depth, bbox_color=colors)
data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor) data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
data_3d['points'] = points data_3d['points'] = points
...@@ -646,10 +664,19 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -646,10 +664,19 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
single_img_meta[key] = meta[i] single_img_meta[key] = meta[i]
else: else:
single_img_meta[key] = meta single_img_meta[key] = meta
max_label = int(
max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.draw_proj_bboxes_3d( self.draw_proj_bboxes_3d(
bboxes_3d, bboxes_3d,
single_img_meta, single_img_meta,
img_size=single_img.shape[:2][::-1]) img_size=single_img.shape[:2][::-1],
edge_colors=colors)
if vis_task == 'mono_det' and hasattr( if vis_task == 'mono_det' and hasattr(
instances, 'centers_2d'): instances, 'centers_2d'):
centers_2d = instances.centers_2d centers_2d = instances.centers_2d
...@@ -668,7 +695,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -668,7 +695,15 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
img = img.permute(1, 2, 0).numpy() img = img.permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb img = img[..., [2, 1, 0]] # bgr to rgb
self.set_image(img) self.set_image(img)
self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
max_label = int(max(labels_3d) if len(labels_3d) > 0 else 0)
bbox_color = palette if self.bbox_color is None \
else self.bbox_color
bbox_palette = get_palette(bbox_color, max_label + 1)
colors = [bbox_palette[label] for label in labels_3d]
self.draw_proj_bboxes_3d(
bboxes_3d, input_meta, edge_colors=colors)
if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'): if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'):
centers_2d = instances.centers_2d centers_2d = instances.centers_2d
self.draw_points(centers_2d) self.draw_points(centers_2d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment