Unverified Commit b297f21d authored by ChaimZhu, committed by GitHub

[Fix] Fix demo bugs due to data flow refactor (#1799)

* fix bugs

* fix bugs 2

* fix seg demo

* update use collate data in mmengine

parent 2f4ea2ef

@@ -94,7 +94,8 @@
     "visualizer.add_datasample(\n",
     "    'result',\n",
     "    data_input,\n",
-    "    pred_sample=result,\n",
+    "    data_sample=result,\n",
+    "    draw_gt=False,\n",
     "    show=True,\n",
     "    wait_time=0,\n",
     "    out_file=out_dir,\n",
@@ -60,7 +60,8 @@ def main(args):
     visualizer.add_datasample(
         'result',
         data_input,
-        pred_sample=result,
+        data_sample=result,
+        draw_gt=False,
         show=True,
         wait_time=0,
         out_file=args.out_dir,
@@ -61,7 +61,8 @@ def main(args):
     visualizer.add_datasample(
         'result',
         data_input,
-        pred_sample=result,
+        data_sample=result,
+        draw_gt=False,
         show=True,
         wait_time=0,
         out_file=args.out_dir,
@@ -50,7 +50,8 @@ def main(args):
     visualizer.add_datasample(
         'result',
         data_input,
-        pred_sample=result,
+        data_sample=result,
+        draw_gt=False,
         show=True,
         wait_time=0,
         out_file=args.out_dir,
@@ -46,7 +46,8 @@ def main(args):
     visualizer.add_datasample(
         'result',
         data_input,
-        pred_sample=result,
+        data_sample=result,
+        draw_gt=False,
         show=True,
         wait_time=0,
         out_file=args.out_dir,
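
Note: all five demo call sites above receive the same two-line fix. After the data-flow refactor, the visualizer's add_datasample no longer has a pred_sample parameter; predictions are passed through data_sample, and draw_gt=False is set explicitly because demo inputs carry no ground-truth annotations to draw. A minimal end-to-end sketch of the updated call site, following the demo scripts' setup (the config, checkpoint, and point-cloud paths are placeholders):

    from mmdet3d.apis import inference_detector, init_model
    from mmdet3d.registry import VISUALIZERS

    # Placeholder paths for illustration only.
    config_file = 'configs/my_config.py'
    checkpoint_file = 'checkpoints/my_checkpoint.pth'
    pcd_file = 'demo/data/my_points.bin'

    model = init_model(config_file, checkpoint_file, device='cuda:0')
    result, data = inference_detector(model, pcd_file)

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    data_input = dict(points=data['inputs']['points'])
    visualizer.add_datasample(
        'result',
        data_input,
        data_sample=result,  # predictions now ride in `data_sample`
        draw_gt=False,       # no ground truth to render for demo inputs
        show=True,
        wait_time=0,
        out_file='outputs/result')
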
@@ -10,7 +10,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from mmengine.config import Config
-from mmengine.dataset import Compose
+from mmengine.dataset import Compose, pseudo_collate
 from mmengine.runner import load_checkpoint
 
 from mmdet3d.registry import MODELS
@@ -60,8 +60,7 @@ def init_model(config: Union[str, Path, Config],
                         f'but got {type(config)}')
     if cfg_options is not None:
         config.merge_from_dict(cfg_options)
-    elif 'init_cfg' in config.model.backbone:
-        config.model.backbone.init_cfg = None
+
     convert_SyncBN(config.model)
     config.model.train_cfg = None
     model = MODELS.build(config.model)
@@ -159,9 +158,11 @@ def inference_detector(model: nn.Module,
         data_ = test_pipeline(data_)
         data.append(data_)
 
+    collate_data = pseudo_collate(data)
+
     # forward the model
     with torch.no_grad():
-        results = model.test_step(data)
+        results = model.test_step(collate_data)
 
     if not is_batch:
         return results[0], data[0]
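
Note: this is the core of the fix. model.test_step feeds its input to the model's data preprocessor, which expects a dataloader-style batch rather than a bare list of pipeline outputs. mmengine's pseudo_collate produces exactly that shape: it recurses into the per-sample dicts and gathers leaves into lists instead of stacking them, which suits point clouds of varying size. A quick sketch of its behavior (shapes are illustrative):

    import torch
    from mmengine.dataset import pseudo_collate

    # Two pipeline outputs, each a dict for one sample. The point clouds
    # have different lengths, so they could not be stacked into one tensor.
    sample_a = dict(inputs=dict(points=torch.rand(1000, 4)))
    sample_b = dict(inputs=dict(points=torch.rand(1500, 4)))

    batch = pseudo_collate([sample_a, sample_b])

    # Leaves are gathered into lists, not stacked.
    print(type(batch['inputs']['points']))     # <class 'list'>
    print(batch['inputs']['points'][0].shape)  # torch.Size([1000, 4])
    print(batch['inputs']['points'][1].shape)  # torch.Size([1500, 4])
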
@@ -245,13 +246,11 @@ def inference_multi_modality_detector(model: nn.Module,
         data_ = test_pipeline(data_)
         data.append(data_)
 
+    collate_data = pseudo_collate(data)
+
     # forward the model
     with torch.no_grad():
-        results = model.test_step(data)
-
-        for index in range(len(data)):
-            meta_info = data[index]['data_samples'].metainfo
-            results[index].set_metainfo(meta_info)
+        results = model.test_step(collate_data)
 
     if not is_batch:
         return results[0], data[0]
@@ -315,13 +314,11 @@ def inference_mono_3d_detector(model: nn.Module,
         data_ = test_pipeline(data_)
         data.append(data_)
 
+    collate_data = pseudo_collate(data)
+
     # forward the model
     with torch.no_grad():
-        results = model.test_step(data)
-
-        for index in range(len(data)):
-            meta_info = data[index]['data_samples'].metainfo
-            results[index].set_metainfo(meta_info)
+        results = model.test_step(collate_data)
 
     if not is_batch:
         return results[0]
@@ -352,7 +349,12 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
     # build the data pipeline
     test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
-    test_pipeline = Compose(test_pipeline)
+    new_test_pipeline = []
+    for pipeline in test_pipeline:
+        if pipeline['type'] != 'LoadAnnotations3D':
+            new_test_pipeline.append(pipeline)
+    test_pipeline = Compose(new_test_pipeline)
+
     data = []
     # TODO: support load points array
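
Note: the segmentation demo previously reused the dataset's test pipeline verbatim, which breaks at demo time because LoadAnnotations3D tries to read ground-truth labels that a raw point cloud does not have, so the transform is filtered out before composing. The same pattern works for any transform that must be dropped at inference; a self-contained sketch with a hypothetical pipeline config:

    # Hypothetical pipeline config; real ones come from
    # cfg.test_dataloader.dataset.pipeline.
    pipeline_cfg = [
        dict(type='LoadPointsFromFile', coord_type='DEPTH', load_dim=6, use_dim=6),
        dict(type='LoadAnnotations3D', with_seg_3d=True),
        dict(type='Pack3DDetInputs', keys=['points']),
    ]

    # Drop annotation loading: demo inputs are unlabeled point clouds.
    inference_cfg = [t for t in pipeline_cfg if t['type'] != 'LoadAnnotations3D']
    assert [t['type'] for t in inference_cfg] == [
        'LoadPointsFromFile', 'Pack3DDetInputs'
    ]
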
@@ -361,9 +363,11 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
         data_ = test_pipeline(data_)
         data.append(data_)
 
+    collate_data = pseudo_collate(data)
+
     # forward the model
     with torch.no_grad():
-        results = model.test_step(data)
+        results = model.test_step(collate_data)
 
     if not is_batch:
         return results[0], data[0]
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 from mmengine.model import BaseModel
 from torch import Tensor
@@ -64,8 +64,8 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
         pass
 
     def forward(self,
-                batch_inputs_dict: Tensor,
-                batch_data_samples: OptSampleList = None,
+                inputs: Union[dict, List[dict]],
+                data_samples: OptSampleList = None,
                 mode: str = 'tensor') -> ForwardResults:
         """The unified entry for a forward process in both training and test.
@@ -82,12 +82,12 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
         optimizer updating, which are done in the :meth:`train_step`.
 
         Args:
-            batch_inputs_dict (dict): Input sample dict which
+            inputs (dict | List[dict]): Input sample dict which
                 includes 'points' and 'imgs' keys.
 
                 - points (list[torch.Tensor]): Point cloud of each sample.
                 - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
 
-            batch_data_samples (list[:obj:`Det3DDataSample`], optional):
+            data_samples (list[:obj:`Det3DDataSample`], optional):
                 The annotation data of every samples. Defaults to None.
             mode (str): Return what kind of value. Defaults to 'tensor'.
@@ -99,11 +99,11 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
             - If ``mode="loss"``, return a dict of tensor.
         """
         if mode == 'loss':
-            return self.loss(batch_inputs_dict, batch_data_samples)
+            return self.loss(inputs, data_samples)
         elif mode == 'predict':
-            return self.predict(batch_inputs_dict, batch_data_samples)
+            return self.predict(inputs, data_samples)
         elif mode == 'tensor':
-            return self._forward(batch_inputs_dict, batch_data_samples)
+            return self._forward(inputs, data_samples)
         else:
             raise RuntimeError(f'Invalid mode "{mode}". '
                                'Only supports loss, predict and tensor mode')
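
Note: the renamed signature aligns Base3DSegmentor with mmengine's BaseModel convention, in which train_step/test_step run the collated batch through the data preprocessor and then call forward(inputs, data_samples, mode=...). A sketch of driving the three modes by hand, assuming model and collate_data from the inference snippets above:

    import torch

    # The data preprocessor unpacks a collated batch into `inputs`
    # and `data_samples` (and moves tensors to the model's device).
    data = model.data_preprocessor(collate_data, training=False)

    with torch.no_grad():
        # 'predict': post-processed results, what test_step returns.
        results = model(**data, mode='predict')
        # 'tensor': raw network outputs, useful for debugging or export.
        raw_outputs = model(**data, mode='tensor')

    # 'loss' is the training path and requires annotated data_samples:
    # losses = model(**data, mode='loss')
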
@@ -354,9 +354,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             bboxes_3d = instances.bboxes_3d  # BaseInstance3DBoxes
 
-        drawn_img = None
-        drawn_points = None
-        drawn_bboxes_3d = None
+        data_3d = dict()
 
         if vis_task in ['det', 'multi_modality-det']:
             assert 'points' in data_input
@@ -372,8 +370,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             self.set_points(points, pcd_mode=2, vis_task=vis_task)
             self.draw_bboxes_3d(bboxes_3d_depth)
 
-            drawn_bboxes_3d = tensor2ndarray(bboxes_3d_depth.tensor)
-            drawn_points = points
+            data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
+            data_3d['points'] = points
 
         if vis_task in ['mono-det', 'multi_modality-det']:
             assert 'img' in data_input
@@ -384,9 +382,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             self.set_image(img)
             self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
             drawn_img = self.get_image()
+            data_3d['img'] = drawn_img
 
-        data_3d = dict(
-            points=drawn_points, img=drawn_img, bboxes_3d=drawn_bboxes_3d)
         return data_3d
 
     def _draw_pts_sem_seg(self,
@@ -578,13 +575,14 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                                                    palette, ignore_index)
 
         # monocular 3d object detection image
-        if gt_data_3d is not None and pred_data_3d is not None:
-            drawn_img_3d = np.concatenate(
-                (gt_data_3d['img'], pred_data_3d['img']), axis=1)
-        elif gt_data_3d is not None:
-            drawn_img_3d = gt_data_3d['img']
-        elif pred_data_3d is not None:
-            drawn_img_3d = pred_data_3d['img']
+        if vis_task in ['mono-det', 'multi_modality-det']:
+            if gt_data_3d is not None and pred_data_3d is not None:
+                drawn_img_3d = np.concatenate(
+                    (gt_data_3d['img'], pred_data_3d['img']), axis=1)
+            elif gt_data_3d is not None:
+                drawn_img_3d = gt_data_3d['img']
+            elif pred_data_3d is not None:
+                drawn_img_3d = pred_data_3d['img']
+            else:
+                drawn_img_3d = None
         else:
             drawn_img_3d = None
@@ -618,10 +616,11 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 write_oriented_bbox(gt_data_3d['bboxes_3d'],
                                     osp.join(out_file, 'gt_bbox.obj'))
             if pred_data_3d is not None:
-                write_obj(pred_data_3d['points'],
-                          osp.join(out_file, 'points.obj'))
-                write_oriented_bbox(pred_data_3d['bboxes_3d'],
-                                    osp.join(out_file, 'pred_bbox.obj'))
+                if 'points' in pred_data_3d:
+                    write_obj(pred_data_3d['points'],
+                              osp.join(out_file, 'points.obj'))
+                write_oriented_bbox(pred_data_3d['bboxes_3d'],
+                                    osp.join(out_file, 'pred_bbox.obj'))
             if gt_seg_data_3d is not None:
                 write_obj(gt_seg_data_3d['points'],
                           osp.join(out_file, 'points.obj'))
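
Note: since data_3d is now built key by key, a prediction dict only contains what its visualization task actually drew; mono-det results have an 'img' but no 'points', so the .obj export guards on key membership before writing. The guard pattern in isolation, with hypothetical payloads:

    import numpy as np

    # Hypothetical payloads: a LiDAR result carries points and boxes,
    # a monocular result carries only the rendered image.
    lidar_pred = dict(points=np.random.rand(100, 3),
                      bboxes_3d=np.random.rand(2, 7))
    mono_pred = dict(img=np.zeros((480, 640, 3), dtype=np.uint8))

    for name, pred in (('lidar', lidar_pred), ('mono', mono_pred)):
        if 'points' in pred:  # only export what was actually drawn
            print(f"{name}: write points.obj ({pred['points'].shape[0]} points)")
        if 'bboxes_3d' in pred:
            print(f"{name}: write pred_bbox.obj ({len(pred['bboxes_3d'])} boxes)")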