"vscode:/vscode.git/clone" did not exist on "4ef91280e4439a727a328f08a82ba680567f987f"
Unverified Commit 77d16764 authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

fix vis hook bug (#1839)

parent 6607f2a7
......@@ -241,6 +241,7 @@ class Det3DDataset(BaseDataset):
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
info['lidar_path'] = info['lidar_points']['lidar_path']
if 'lidar_sweeps' in info:
for sweep in info['lidar_sweeps']:
......
......@@ -68,8 +68,9 @@ class Pack3DDetInputs(BaseTransform):
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'pcd_trans', 'sample_idx', 'pcd_scale_factor',
'pcd_rotation', 'pcd_rotation_angle', 'lidar_path',
'num_pts_feats', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')):
self.keys = keys
......
......@@ -4,6 +4,7 @@ import warnings
from typing import Optional, Sequence
import mmcv
import numpy as np
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.runner import Runner
......@@ -95,15 +96,27 @@ class Det3DVisualizationHook(Hook):
# is visualized for each evaluation.
total_curr_iter = runner.iter + batch_idx
data_input = dict()
# Visualize only the first data
img_path = outputs[0].img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
if 'img_path' in outputs[0]:
img_path = outputs[0].img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input['img'] = img
if 'lidar_path' in outputs[0]:
lidar_path = outputs[0].lidar_path
num_pts_feats = outputs[0].num_pts_feats
pts_bytes = self.file_client.get(lidar_path)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, num_pts_feats)
data_input['points'] = points
if total_curr_iter % self.interval == 0:
self._visualizer.add_datasample(
osp.basename(img_path) if self.show else 'val_img',
img,
'val sample',
data_input,
data_sample=outputs[0],
show=self.show,
wait_time=self.wait_time,
......@@ -135,18 +148,28 @@ class Det3DVisualizationHook(Hook):
for data_sample in outputs:
self._test_index += 1
img_path = data_sample.img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input = dict()
if 'img_path' in data_sample:
img_path = data_sample.img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input['img'] = img
if 'lidar_path' in data_sample:
lidar_path = data_sample.lidar_path
num_pts_feats = data_sample.num_pts_feats
pts_bytes = self.file_client.get(lidar_path)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, num_pts_feats)
data_input['points'] = points
out_file = None
if self.test_out_dir is not None:
out_file = osp.basename(img_path)
out_file = osp.join(self.test_out_dir, out_file)
out_file = self.test_out_dir
self._visualizer.add_datasample(
osp.basename(img_path) if self.show else 'test_img',
img,
'test sample',
data_input,
data_sample=data_sample,
show=self.show,
wait_time=self.wait_time,
......
......@@ -525,14 +525,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
data_input, data_sample.gt_instances_3d,
data_sample.metainfo, vis_task, palette)
if 'gt_instances' in data_sample:
assert 'img' in data_input
if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
gt_img_data = self._draw_instances(img,
data_sample.gt_instances,
classes, palette)
if 'gt_pts_seg' in data_sample:
if len(data_sample.gt_instances) > 0:
assert 'img' in data_input
if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
gt_img_data = self._draw_instances(
img, data_sample.gt_instances, classes, palette)
if 'gt_pts_seg' in data_sample and vis_task == 'seg':
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing panoptic ' \
......
......@@ -5,42 +5,59 @@ import time
from unittest import TestCase
from unittest.mock import Mock
import numpy as np
import torch
from mmengine.structures import InstanceData
from mmdet3d.engine.hooks import Det3DVisualizationHook
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
from mmdet3d.visualization import Det3DLocalVisualizer
def _rand_bboxes(num_boxes, h, w):
cx, cy, bw, bh = torch.rand(num_boxes, 4).T
tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
bboxes = torch.vstack([tl_x, tl_y, br_x, br_y]).T
return bboxes
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
Det3DLocalVisualizer.get_instance('visualizer')
pred_instances = InstanceData()
pred_instances.bboxes = _rand_bboxes(5, 10, 12)
pred_instances.labels = torch.randint(0, 2, (5, ))
pred_instances.scores = torch.rand((5, ))
pred_det_data_sample = Det3DDataSample()
pred_det_data_sample.set_metainfo({
pred_instances_3d = InstanceData()
pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, -1.5808]]))
pred_instances_3d.labels_3d = torch.tensor([0])
pred_instances_3d.scores_3d = torch.tensor([0.8])
pred_det3d_data_sample = Det3DDataSample()
pred_det3d_data_sample.set_metainfo({
'num_pts_feats':
4,
'lidar2img':
np.array([[
6.02943734e+02, -7.07913286e+02, -1.22748427e+01,
-1.70942724e+02
],
[
1.76777261e+02, 8.80879902e+00, -7.07936120e+02,
-1.02568636e+02
],
[
9.99984860e-01, -1.52826717e-03, -5.29071223e-03,
-3.27567990e-01
],
[
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
1.00000000e+00
]]),
'img_path':
osp.join(osp.dirname(__file__), '../../data/color.jpg')
osp.join(
osp.dirname(__file__),
'../../data/kitti/training/image_2/000000.png'),
'lidar_path':
osp.join(
osp.dirname(__file__),
'../../data/kitti/training/velodyne_reduced/000000.bin')
})
pred_det_data_sample.pred_instances = pred_instances
self.outputs = [pred_det_data_sample] * 2
pred_det3d_data_sample.pred_instances_3d = pred_instances_3d
self.outputs = [pred_det3d_data_sample] * 2
def test_after_val_iter(self):
runner = Mock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment