Unverified Commit c2d958aa authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Refactor] update data flow and ut (#1776)

* update data flow and ut

* update ut

* update code

* fix mapping bug

* fix comments
parent c2c5abd6
...@@ -5,7 +5,7 @@ from os.path import dirname, exists, join ...@@ -5,7 +5,7 @@ from os.path import dirname, exists, join
import numpy as np import numpy as np
import torch import torch
from mmengine import InstanceData from mmengine.structures import InstanceData
from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, Det3DDataSample, LiDARInstance3DBoxes,
...@@ -101,10 +101,13 @@ def _create_detector_inputs(seed=0, ...@@ -101,10 +101,13 @@ def _create_detector_inputs(seed=0,
[[5.23289349e+02, 3.68831943e+02, 6.10469439e+01], [[5.23289349e+02, 3.68831943e+02, 6.10469439e+01],
[1.09560138e+02, 1.97404735e+02, -5.47377738e+02], [1.09560138e+02, 1.97404735e+02, -5.47377738e+02],
[1.25930002e-02, 9.92229998e-01, -1.23769999e-01]]) [1.25930002e-02, 9.92229998e-01, -1.23769999e-01]])
inputs_dict = dict()
if with_points: if with_points:
points = torch.rand([num_points, points_feat_dim]) points = torch.rand([num_points, points_feat_dim])
else: inputs_dict['points'] = [points]
points = None
if with_img: if with_img:
if isinstance(img_size, tuple): if isinstance(img_size, tuple):
img = torch.rand(3, img_size[0], img_size[1]) img = torch.rand(3, img_size[0], img_size[1])
...@@ -115,10 +118,8 @@ def _create_detector_inputs(seed=0, ...@@ -115,10 +118,8 @@ def _create_detector_inputs(seed=0,
meta_info['img_shape'] = (img_size, img_size) meta_info['img_shape'] = (img_size, img_size)
meta_info['ori_shape'] = (img_size, img_size) meta_info['ori_shape'] = (img_size, img_size)
meta_info['scale_factor'] = np.array([1., 1.]) meta_info['scale_factor'] = np.array([1., 1.])
inputs_dict['img'] = [img]
else:
img = None
inputs_dict = dict(img=img, points=points)
gt_instance_3d = InstanceData() gt_instance_3d = InstanceData()
gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type]( gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
...@@ -145,4 +146,4 @@ def _create_detector_inputs(seed=0, ...@@ -145,4 +146,4 @@ def _create_detector_inputs(seed=0,
pts_semantic_mask = torch.randint(0, num_classes, [num_points]) pts_semantic_mask = torch.randint(0, num_classes, [num_points])
data_sample.gt_pts_seg['pts_semantic_mask'] = pts_semantic_mask data_sample.gt_pts_seg['pts_semantic_mask'] = pts_semantic_mask
return dict(inputs=inputs_dict, data_sample=data_sample) return dict(inputs=inputs_dict, data_samples=[data_sample])
...@@ -118,7 +118,7 @@ def main(): ...@@ -118,7 +118,7 @@ def main():
for item in dataset: for item in dataset:
# the 3D Boxes in input could be in any of three coordinates # the 3D Boxes in input could be in any of three coordinates
data_input = item['inputs'] data_input = item['inputs']
data_sample = item['data_sample'].numpy() data_sample = item['data_samples'].numpy()
out_file = osp.join( out_file = osp.join(
args.output_dir) if args.output_dir is not None else None args.output_dir) if args.output_dir is not None else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment