Unverified Commit b297f21d authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix] Fix demo bugs due to data flow refactor (#1799)

* fix bugs

* fix bugs 2

* fix seg demo

* update use collate data in mmengine
parent 2f4ea2ef
...@@ -94,7 +94,8 @@ ...@@ -94,7 +94,8 @@
"visualizer.add_datasample(\n", "visualizer.add_datasample(\n",
" 'result',\n", " 'result',\n",
" data_input,\n", " data_input,\n",
" pred_sample=result,\n", " data_sample=result,\n",
" draw_gt=False,\n",
" show=True,\n", " show=True,\n",
" wait_time=0,\n", " wait_time=0,\n",
" out_file=out_dir,\n", " out_file=out_dir,\n",
......
...@@ -60,7 +60,8 @@ def main(args): ...@@ -60,7 +60,8 @@ def main(args):
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
data_input, data_input,
pred_sample=result, data_sample=result,
draw_gt=False,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_dir, out_file=args.out_dir,
......
...@@ -61,7 +61,8 @@ def main(args): ...@@ -61,7 +61,8 @@ def main(args):
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
data_input, data_input,
pred_sample=result, data_sample=result,
draw_gt=False,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_dir, out_file=args.out_dir,
......
...@@ -50,7 +50,8 @@ def main(args): ...@@ -50,7 +50,8 @@ def main(args):
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
data_input, data_input,
pred_sample=result, data_sample=result,
draw_gt=False,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_dir, out_file=args.out_dir,
......
...@@ -46,7 +46,8 @@ def main(args): ...@@ -46,7 +46,8 @@ def main(args):
visualizer.add_datasample( visualizer.add_datasample(
'result', 'result',
data_input, data_input,
pred_sample=result, data_sample=result,
draw_gt=False,
show=True, show=True,
wait_time=0, wait_time=0,
out_file=args.out_dir, out_file=args.out_dir,
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.config import Config from mmengine.config import Config
from mmengine.dataset import Compose from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner import load_checkpoint from mmengine.runner import load_checkpoint
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -60,8 +60,7 @@ def init_model(config: Union[str, Path, Config], ...@@ -60,8 +60,7 @@ def init_model(config: Union[str, Path, Config],
f'but got {type(config)}') f'but got {type(config)}')
if cfg_options is not None: if cfg_options is not None:
config.merge_from_dict(cfg_options) config.merge_from_dict(cfg_options)
elif 'init_cfg' in config.model.backbone:
config.model.backbone.init_cfg = None
convert_SyncBN(config.model) convert_SyncBN(config.model)
config.model.train_cfg = None config.model.train_cfg = None
model = MODELS.build(config.model) model = MODELS.build(config.model)
...@@ -159,9 +158,11 @@ def inference_detector(model: nn.Module, ...@@ -159,9 +158,11 @@ def inference_detector(model: nn.Module,
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
collate_data = pseudo_collate(data)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(collate_data)
if not is_batch: if not is_batch:
return results[0], data[0] return results[0], data[0]
...@@ -245,13 +246,11 @@ def inference_multi_modality_detector(model: nn.Module, ...@@ -245,13 +246,11 @@ def inference_multi_modality_detector(model: nn.Module,
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
collate_data = pseudo_collate(data)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(collate_data)
for index in range(len(data)):
meta_info = data[index]['data_samples'].metainfo
results[index].set_metainfo(meta_info)
if not is_batch: if not is_batch:
return results[0], data[0] return results[0], data[0]
...@@ -315,13 +314,11 @@ def inference_mono_3d_detector(model: nn.Module, ...@@ -315,13 +314,11 @@ def inference_mono_3d_detector(model: nn.Module,
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
collate_data = pseudo_collate(data)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(collate_data)
for index in range(len(data)):
meta_info = data[index]['data_samples'].metainfo
results[index].set_metainfo(meta_info)
if not is_batch: if not is_batch:
return results[0] return results[0]
...@@ -352,7 +349,12 @@ def inference_segmentor(model: nn.Module, pcds: PointsType): ...@@ -352,7 +349,12 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
# build the data pipeline # build the data pipeline
test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
test_pipeline = Compose(test_pipeline)
new_test_pipeline = []
for pipeline in test_pipeline:
if pipeline['type'] != 'LoadAnnotations3D':
new_test_pipeline.append(pipeline)
test_pipeline = Compose(new_test_pipeline)
data = [] data = []
# TODO: support load points array # TODO: support load points array
...@@ -361,9 +363,11 @@ def inference_segmentor(model: nn.Module, pcds: PointsType): ...@@ -361,9 +363,11 @@ def inference_segmentor(model: nn.Module, pcds: PointsType):
data_ = test_pipeline(data_) data_ = test_pipeline(data_)
data.append(data_) data.append(data_)
collate_data = pseudo_collate(data)
# forward the model # forward the model
with torch.no_grad(): with torch.no_grad():
results = model.test_step(data) results = model.test_step(collate_data)
if not is_batch: if not is_batch:
return results[0], data[0] return results[0], data[0]
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import List, Tuple from typing import List, Tuple, Union
from mmengine.model import BaseModel from mmengine.model import BaseModel
from torch import Tensor from torch import Tensor
...@@ -64,8 +64,8 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta): ...@@ -64,8 +64,8 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
pass pass
def forward(self, def forward(self,
batch_inputs_dict: Tensor, inputs: Union[dict, List[dict]],
batch_data_samples: OptSampleList = None, data_samples: OptSampleList = None,
mode: str = 'tensor') -> ForwardResults: mode: str = 'tensor') -> ForwardResults:
"""The unified entry for a forward process in both training and test. """The unified entry for a forward process in both training and test.
...@@ -82,12 +82,12 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta): ...@@ -82,12 +82,12 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
optimizer updating, which are done in the :meth:`train_step`. optimizer updating, which are done in the :meth:`train_step`.
Args: Args:
batch_inputs_dict (dict): Input sample dict which inputs (dict | List[dict]): Input sample dict which
includes 'points' and 'imgs' keys. includes 'points' and 'imgs' keys.
- points (list[torch.Tensor]): Point cloud of each sample. - points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Image tensor has shape (B, C, H, W). - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
batch_data_samples (list[:obj:`Det3DDataSample`], optional): data_samples (list[:obj:`Det3DDataSample`], optional):
The annotation data of every samples. Defaults to None. The annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'. mode (str): Return what kind of value. Defaults to 'tensor'.
...@@ -99,11 +99,11 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta): ...@@ -99,11 +99,11 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor. - If ``mode="loss"``, return a dict of tensor.
""" """
if mode == 'loss': if mode == 'loss':
return self.loss(batch_inputs_dict, batch_data_samples) return self.loss(inputs, data_samples)
elif mode == 'predict': elif mode == 'predict':
return self.predict(batch_inputs_dict, batch_data_samples) return self.predict(inputs, data_samples)
elif mode == 'tensor': elif mode == 'tensor':
return self._forward(batch_inputs_dict, batch_data_samples) return self._forward(inputs, data_samples)
else: else:
raise RuntimeError(f'Invalid mode "{mode}". ' raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode') 'Only supports loss, predict and tensor mode')
......
...@@ -354,9 +354,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -354,9 +354,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes bboxes_3d = instances.bboxes_3d # BaseInstance3DBoxes
drawn_img = None data_3d = dict()
drawn_points = None
drawn_bboxes_3d = None
if vis_task in ['det', 'multi_modality-det']: if vis_task in ['det', 'multi_modality-det']:
assert 'points' in data_input assert 'points' in data_input
...@@ -372,8 +370,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -372,8 +370,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.set_points(points, pcd_mode=2, vis_task=vis_task) self.set_points(points, pcd_mode=2, vis_task=vis_task)
self.draw_bboxes_3d(bboxes_3d_depth) self.draw_bboxes_3d(bboxes_3d_depth)
drawn_bboxes_3d = tensor2ndarray(bboxes_3d_depth.tensor) data_3d['bboxes_3d'] = tensor2ndarray(bboxes_3d_depth.tensor)
drawn_points = points data_3d['points'] = points
if vis_task in ['mono-det', 'multi_modality-det']: if vis_task in ['mono-det', 'multi_modality-det']:
assert 'img' in data_input assert 'img' in data_input
...@@ -384,9 +382,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -384,9 +382,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
self.set_image(img) self.set_image(img)
self.draw_proj_bboxes_3d(bboxes_3d, input_meta) self.draw_proj_bboxes_3d(bboxes_3d, input_meta)
drawn_img = self.get_image() drawn_img = self.get_image()
data_3d['img'] = drawn_img
data_3d = dict(
points=drawn_points, img=drawn_img, bboxes_3d=drawn_bboxes_3d)
return data_3d return data_3d
def _draw_pts_sem_seg(self, def _draw_pts_sem_seg(self,
...@@ -578,6 +575,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -578,6 +575,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
palette, ignore_index) palette, ignore_index)
# monocular 3d object detection image # monocular 3d object detection image
if vis_task in ['mono-det', 'multi_modality-det']:
if gt_data_3d is not None and pred_data_3d is not None: if gt_data_3d is not None and pred_data_3d is not None:
drawn_img_3d = np.concatenate( drawn_img_3d = np.concatenate(
(gt_data_3d['img'], pred_data_3d['img']), axis=1) (gt_data_3d['img'], pred_data_3d['img']), axis=1)
...@@ -618,6 +616,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -618,6 +616,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
write_oriented_bbox(gt_data_3d['bboxes_3d'], write_oriented_bbox(gt_data_3d['bboxes_3d'],
osp.join(out_file, 'gt_bbox.obj')) osp.join(out_file, 'gt_bbox.obj'))
if pred_data_3d is not None: if pred_data_3d is not None:
if 'points' in pred_data_3d:
write_obj(pred_data_3d['points'], write_obj(pred_data_3d['points'],
osp.join(out_file, 'points.obj')) osp.join(out_file, 'points.obj'))
write_oriented_bbox(pred_data_3d['bboxes_3d'], write_oriented_bbox(pred_data_3d['bboxes_3d'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment