"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "388c520be21752cacb9fe3b1712038f32e0e9a5f"
Unverified Commit 77d16764 authored by ChaimZhu, committed by GitHub

fix vis hook bug (#1839)

parent 6607f2a7
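For context: the hook patched below is normally turned on through the runner's default_hooks setting. The snippet is only an illustrative sketch of such a config, with option names assumed from MMDetection3D 1.x conventions rather than taken from this commit:

# Illustrative config sketch (assumed names/values, not part of this commit):
# how Det3DVisualizationHook and Det3DLocalVisualizer are typically enabled
# so that the hook below runs during validation and testing.
default_hooks = dict(
    visualization=dict(
        type='Det3DVisualizationHook',
        draw=True,       # assumed flag: actually load and draw samples
        interval=50,     # visualize every 50th validation sample
        show=False))     # set True to pop up an interactive window

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')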
@@ -241,6 +241,7 @@ class Det3DDataset(BaseDataset):
                     self.data_prefix.get('pts', ''),
                     info['lidar_points']['lidar_path'])
+            info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
             info['lidar_path'] = info['lidar_points']['lidar_path']
             if 'lidar_sweeps' in info:
                 for sweep in info['lidar_sweeps']:
......
@@ -68,8 +68,9 @@ class Pack3DDetInputs(BaseTransform):
                      'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
                      'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
                      'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
-                     'pcd_trans', 'sample_idx', 'pcd_scale_factor',
-                     'pcd_rotation', 'pcd_rotation_angle', 'lidar_path',
+                     'num_pts_feats', 'pcd_trans', 'sample_idx',
+                     'pcd_scale_factor', 'pcd_rotation',
+                     'pcd_rotation_angle', 'lidar_path',
                      'transformation_3d_flow', 'trans_mat',
                      'affine_aug')):
         self.keys = keys
......
@@ -4,6 +4,7 @@ import warnings
 from typing import Optional, Sequence

 import mmcv
+import numpy as np
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.runner import Runner
@@ -95,15 +96,27 @@ class Det3DVisualizationHook(Hook):
         # is visualized for each evaluation.
         total_curr_iter = runner.iter + batch_idx

+        data_input = dict()
+
         # Visualize only the first data
-        img_path = outputs[0].img_path
-        img_bytes = self.file_client.get(img_path)
-        img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+        if 'img_path' in outputs[0]:
+            img_path = outputs[0].img_path
+            img_bytes = self.file_client.get(img_path)
+            img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+            data_input['img'] = img
+
+        if 'lidar_path' in outputs[0]:
+            lidar_path = outputs[0].lidar_path
+            num_pts_feats = outputs[0].num_pts_feats
+            pts_bytes = self.file_client.get(lidar_path)
+            points = np.frombuffer(pts_bytes, dtype=np.float32)
+            points = points.reshape(-1, num_pts_feats)
+            data_input['points'] = points

         if total_curr_iter % self.interval == 0:
             self._visualizer.add_datasample(
-                osp.basename(img_path) if self.show else 'val_img',
-                img,
+                'val sample',
+                data_input,
                 data_sample=outputs[0],
                 show=self.show,
                 wait_time=self.wait_time,
@@ -135,18 +148,28 @@ class Det3DVisualizationHook(Hook):
         for data_sample in outputs:
             self._test_index += 1

-            img_path = data_sample.img_path
-            img_bytes = self.file_client.get(img_path)
-            img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+            data_input = dict()
+            if 'img_path' in data_sample:
+                img_path = data_sample.img_path
+                img_bytes = self.file_client.get(img_path)
+                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+                data_input['img'] = img
+
+            if 'lidar_path' in data_sample:
+                lidar_path = data_sample.lidar_path
+                num_pts_feats = data_sample.num_pts_feats
+                pts_bytes = self.file_client.get(lidar_path)
+                points = np.frombuffer(pts_bytes, dtype=np.float32)
+                points = points.reshape(-1, num_pts_feats)
+                data_input['points'] = points

             out_file = None
             if self.test_out_dir is not None:
-                out_file = osp.basename(img_path)
-                out_file = osp.join(self.test_out_dir, out_file)
+                out_file = self.test_out_dir

             self._visualizer.add_datasample(
-                osp.basename(img_path) if self.show else 'test_img',
-                img,
+                'test sample',
+                data_input,
                 data_sample=data_sample,
                 show=self.show,
                 wait_time=self.wait_time,
......
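Both hook methods above now load the point cloud the same way: raw float32 bytes reshaped to (num_points, num_pts_feats). Below is a standalone sketch of that pattern with a made-up path and feature count; the hook itself reads the bytes through self.file_client.get rather than open().

import numpy as np

# Hypothetical KITTI-style .bin file and feature count, for illustration only.
lidar_path = 'data/kitti/training/velodyne_reduced/000000.bin'
num_pts_feats = 4  # x, y, z, reflectance

with open(lidar_path, 'rb') as f:
    pts_bytes = f.read()
points = np.frombuffer(pts_bytes, dtype=np.float32).reshape(-1, num_pts_feats)
print(points.shape)  # (N, 4), the array passed on as data_input['points']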
@@ -525,14 +525,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                     data_input, data_sample.gt_instances_3d,
                     data_sample.metainfo, vis_task, palette)
             if 'gt_instances' in data_sample:
-                assert 'img' in data_input
-                if isinstance(data_input['img'], Tensor):
-                    img = data_input['img'].permute(1, 2, 0).numpy()
-                    img = img[..., [2, 1, 0]]  # bgr to rgb
-                gt_img_data = self._draw_instances(img,
-                                                   data_sample.gt_instances,
-                                                   classes, palette)
-            if 'gt_pts_seg' in data_sample:
+                if len(data_sample.gt_instances) > 0:
+                    assert 'img' in data_input
+                    if isinstance(data_input['img'], Tensor):
+                        img = data_input['img'].permute(1, 2, 0).numpy()
+                        img = img[..., [2, 1, 0]]  # bgr to rgb
+                    gt_img_data = self._draw_instances(
+                        img, data_sample.gt_instances, classes, palette)
+            if 'gt_pts_seg' in data_sample and vis_task == 'seg':
                 assert classes is not None, 'class information is ' \
                                             'not provided when ' \
                                             'visualizing panoptic ' \
......
@@ -5,42 +5,59 @@ import time
 from unittest import TestCase
 from unittest.mock import Mock

+import numpy as np
 import torch
 from mmengine.structures import InstanceData

 from mmdet3d.engine.hooks import Det3DVisualizationHook
-from mmdet3d.structures import Det3DDataSample
+from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
 from mmdet3d.visualization import Det3DLocalVisualizer


-def _rand_bboxes(num_boxes, h, w):
-    cx, cy, bw, bh = torch.rand(num_boxes, 4).T
-    tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
-    tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
-    br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
-    br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
-    bboxes = torch.vstack([tl_x, tl_y, br_x, br_y]).T
-    return bboxes
-
-
 class TestVisualizationHook(TestCase):

     def setUp(self) -> None:
         Det3DLocalVisualizer.get_instance('visualizer')

-        pred_instances = InstanceData()
-        pred_instances.bboxes = _rand_bboxes(5, 10, 12)
-        pred_instances.labels = torch.randint(0, 2, (5, ))
-        pred_instances.scores = torch.rand((5, ))
-        pred_det_data_sample = Det3DDataSample()
-        pred_det_data_sample.set_metainfo({
+        pred_instances_3d = InstanceData()
+        pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
+            torch.tensor(
+                [[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, -1.5808]]))
+        pred_instances_3d.labels_3d = torch.tensor([0])
+        pred_instances_3d.scores_3d = torch.tensor([0.8])
+
+        pred_det3d_data_sample = Det3DDataSample()
+        pred_det3d_data_sample.set_metainfo({
+            'num_pts_feats':
+            4,
+            'lidar2img':
+            np.array([[
+                6.02943734e+02, -7.07913286e+02, -1.22748427e+01,
+                -1.70942724e+02
+            ],
+                      [
+                          1.76777261e+02, 8.80879902e+00, -7.07936120e+02,
+                          -1.02568636e+02
+                      ],
+                      [
+                          9.99984860e-01, -1.52826717e-03, -5.29071223e-03,
+                          -3.27567990e-01
+                      ],
+                      [
+                          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
+                          1.00000000e+00
+                      ]]),
             'img_path':
-            osp.join(osp.dirname(__file__), '../../data/color.jpg')
+            osp.join(
+                osp.dirname(__file__),
+                '../../data/kitti/training/image_2/000000.png'),
+            'lidar_path':
+            osp.join(
+                osp.dirname(__file__),
+                '../../data/kitti/training/velodyne_reduced/000000.bin')
         })
-        pred_det_data_sample.pred_instances = pred_instances
-        self.outputs = [pred_det_data_sample] * 2
+        pred_det3d_data_sample.pred_instances_3d = pred_instances_3d
+        self.outputs = [pred_det3d_data_sample] * 2

     def test_after_val_iter(self):
         runner = Mock()
......
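The rewritten test fixture builds its fake prediction with LiDARInstance3DBoxes, where each row follows the (x, y, z, x_size, y_size, z_size, yaw) convention in the LiDAR frame. A minimal sketch of constructing such a sample outside the test, with arbitrary values:

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes

# One fake 3D box: (x, y, z, x_size, y_size, z_size, yaw) in the LiDAR frame.
pred_instances_3d = InstanceData()
pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
    torch.tensor([[8.73, -1.86, -1.60, 1.20, 0.48, 1.89, -1.58]]))
pred_instances_3d.labels_3d = torch.tensor([0])
pred_instances_3d.scores_3d = torch.tensor([0.8])

data_sample = Det3DDataSample()
data_sample.pred_instances_3d = pred_instances_3d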