Commit 6f70f7e7 authored by Shilong Zhang, committed by ChaimZhu

[Fix] fix pillars encode (#1689)



* fix pillars encode

* fix all configs

* fix load key
Co-authored-by: VVsssssk <shenkun@pjlab.org.cn>
parent 970589c5
@@ -31,13 +31,9 @@ test_pipeline = [
                 translation_std=[0, 0, 0]),
             dict(type='RandomFlip3D', sync_2d=False),
             dict(
-                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 data = dict(
......
@@ -33,13 +33,9 @@ test_pipeline = [
                 translation_std=[0, 0, 0]),
             dict(type='RandomFlip3D', sync_2d=False),
             dict(
-                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 data = dict(
......
@@ -33,12 +33,8 @@ test_pipeline = [
             dict(type='RandomFlip3D', sync_2d=False),
             dict(
                 type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 data = dict(
......
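
The three config hunks above make the same change: the trailing DefaultFormatBundle3D + Collect3D pair in test_pipeline is replaced by the single Pack3DDetInputs transform used by the dev-1.x data flow, so class_names is no longer needed at this point in the pipeline. As a minimal sketch of a migrated test_pipeline, with placeholder values for point_cloud_range and the loading / TTA arguments (they are not copied from these configs):

point_cloud_range = [-50, -50, -5, 50, 50, 3]  # placeholder range
test_pipeline = [
    dict(
        type='LoadPointsFromFile',  # placeholder loading arguments
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D', sync_2d=False),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
        ]),
    # Pack3DDetInputs both formats the points and collects the requested
    # keys, replacing the former DefaultFormatBundle3D + Collect3D pair.
    dict(type='Pack3DDetInputs', keys=['points'])
]
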
@@ -112,23 +112,7 @@ test_pipeline = [
         ]),
     dict(type='Pack3DDetInputs', keys=['points'])
 ]
-# construct a pipeline for data and gt loading in show function
-# please keep its loading function consistent with test_pipeline (e.g. client)
-eval_pipeline = [
-    dict(
-        type='LoadPointsFromFile',
-        coord_type='LIDAR',
-        load_dim=5,
-        use_dim=5,
-    ),
-    dict(
-        type='LoadPointsFromMultiSweeps',
-        sweeps_num=9,
-        use_dim=[0, 1, 2, 3, 4],
-        pad_empty_sweeps=True,
-        remove_close=True),
-    dict(type='Pack3DDetInputs', keys=['points'])
-]
 train_dataloader = dict(
     _delete_=True,
     batch_size=4,
......
@@ -111,18 +111,7 @@ test_pipeline = [
         ]),
     dict(type='Pack3DDetInputs', keys=['points'])
 ]
-# construct a pipeline for data and gt loading in show function
-# please keep its loading function consistent with test_pipeline (e.g. client)
-eval_pipeline = [
-    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
-    dict(
-        type='LoadPointsFromMultiSweeps',
-        sweeps_num=9,
-        use_dim=[0, 1, 2, 3, 4],
-        pad_empty_sweeps=True,
-        remove_close=True),
-    dict(type='Pack3DDetInputs', keys=['points'])
-]
 train_dataloader = dict(
     _delete_=True,
     batch_size=4,
......
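
The two hunks above additionally delete the standalone eval_pipeline. As its own comment says, it was only a duplicate loading pipeline for the old show/eval utilities; in the dev-1.x flow, evaluation and visualization consume data through the val/test dataloaders, which already point at test_pipeline. A minimal sketch of that wiring, with placeholder dataset type, paths and fields (a real config carries more dataset options):

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='NuScenesDataset',             # placeholder dataset type
        data_root='data/nuscenes/',         # placeholder path
        ann_file='nuscenes_infos_val.pkl',  # placeholder annotation file
        pipeline=test_pipeline,             # reuse test_pipeline directly
        test_mode=True))
test_dataloader = val_dataloader
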
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Optional
-import torch
-from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d
 from mmdet3d.registry import MODELS
 from .mvx_two_stage import MVXTwoStageDetector
@@ -70,118 +67,3 @@ class CenterPoint(MVXTwoStageDetector):
                          pts_bbox_head, img_roi_head, img_rpn_head,
                          train_cfg, test_cfg, init_cfg, data_preprocessor,
                          **kwargs)
-    # TODO support this
-    def aug_test_pts(self, feats, img_metas, rescale=False):
-        """Test function of point cloud branch with augmentaiton.
-        The function implementation process is as follows:
-            - step 1: map features back for double-flip augmentation.
-            - step 2: merge all features and generate boxes.
-            - step 3: map boxes back for scale augmentation.
-            - step 4: merge results.
-        Args:
-            feats (list[torch.Tensor]): Feature of point cloud.
-            img_metas (list[dict]): Meta information of samples.
-            rescale (bool, optional): Whether to rescale bboxes.
-                Default: False.
-        Returns:
-            dict: Returned bboxes consists of the following keys:
-                - boxes_3d (:obj:`LiDARInstance3DBoxes`): Predicted bboxes.
-                - scores_3d (torch.Tensor): Scores of predicted boxes.
-                - labels_3d (torch.Tensor): Labels of predicted boxes.
-        """
-        raise NotImplementedError
-        # only support aug_test for one sample
-        outs_list = []
-        for x, img_meta in zip(feats, img_metas):
-            outs = self.pts_bbox_head(x)
-            # merge augmented outputs before decoding bboxes
-            for task_id, out in enumerate(outs):
-                for key in out[0].keys():
-                    if img_meta[0]['pcd_horizontal_flip']:
-                        outs[task_id][0][key] = torch.flip(
-                            outs[task_id][0][key], dims=[2])
-                        if key == 'reg':
-                            outs[task_id][0][key][:, 1, ...] = 1 - outs[
-                                task_id][0][key][:, 1, ...]
-                        elif key == 'rot':
-                            outs[task_id][0][
-                                key][:, 0,
-                                     ...] = -outs[task_id][0][key][:, 0, ...]
-                        elif key == 'vel':
-                            outs[task_id][0][
-                                key][:, 1,
-                                     ...] = -outs[task_id][0][key][:, 1, ...]
-                    if img_meta[0]['pcd_vertical_flip']:
-                        outs[task_id][0][key] = torch.flip(
-                            outs[task_id][0][key], dims=[3])
-                        if key == 'reg':
-                            outs[task_id][0][key][:, 0, ...] = 1 - outs[
-                                task_id][0][key][:, 0, ...]
-                        elif key == 'rot':
-                            outs[task_id][0][
-                                key][:, 1,
-                                     ...] = -outs[task_id][0][key][:, 1, ...]
-                        elif key == 'vel':
-                            outs[task_id][0][
-                                key][:, 0,
-                                     ...] = -outs[task_id][0][key][:, 0, ...]
-            outs_list.append(outs)
-        preds_dicts = dict()
-        scale_img_metas = []
-        # concat outputs sharing the same pcd_scale_factor
-        for i, (img_meta, outs) in enumerate(zip(img_metas, outs_list)):
-            pcd_scale_factor = img_meta[0]['pcd_scale_factor']
-            if pcd_scale_factor not in preds_dicts.keys():
-                preds_dicts[pcd_scale_factor] = outs
-                scale_img_metas.append(img_meta)
-            else:
-                for task_id, out in enumerate(outs):
-                    for key in out[0].keys():
-                        preds_dicts[pcd_scale_factor][task_id][0][key] += out[
-                            0][key]
-        aug_bboxes = []
-        for pcd_scale_factor, preds_dict in preds_dicts.items():
-            for task_id, pred_dict in enumerate(preds_dict):
-                # merge outputs with different flips before decoding bboxes
-                for key in pred_dict[0].keys():
-                    preds_dict[task_id][0][key] /= len(outs_list) / len(
-                        preds_dicts.keys())
-            bbox_list = self.pts_bbox_head.get_bboxes(
-                preds_dict, img_metas[0], rescale=rescale)
-            bbox_list = [
-                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
-                for bboxes, scores, labels in bbox_list
-            ]
-            aug_bboxes.append(bbox_list[0])
-        if len(preds_dicts.keys()) > 1:
-            # merge outputs with different scales after decoding bboxes
-            merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, scale_img_metas,
-                                                self.pts_bbox_head.test_cfg)
-            return merged_bboxes
-        else:
-            for key in bbox_list[0].keys():
-                bbox_list[0][key] = bbox_list[0][key].to('cpu')
-            return bbox_list[0]
-    # TODO support this
-    def aug_test(self, points, img_metas, imgs=None, rescale=False):
-        raise NotImplementedError
-        """Test function with augmentaiton."""
-        img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
-        bbox_list = dict()
-        if pts_feats and self.with_pts_bbox:
-            pts_bbox = self.aug_test_pts(pts_feats, img_metas, rescale)
-            bbox_list.update(pts_bbox=pts_bbox)
-        return [bbox_list]
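
The two deleted methods implemented double-flip / multi-scale test-time augmentation for the point-cloud branch; both already raised NotImplementedError and are dropped here rather than ported. The core of the removed logic is the flip-back step: each dense head output is flipped back to the original BEV orientation, the sub-pixel 'reg' offset along the flipped axis is mirrored (x -> 1 - x), and the affected 'rot' and 'vel' components change sign. A self-contained sketch of that step, using a hypothetical helper that is not part of mmdet3d:

import torch


def flip_back_bev_preds(preds, horizontal, vertical):
    """Undo BEV flips on dense CenterPoint-style head outputs.

    Illustrative helper only; it mirrors the flip handling of the removed
    ``aug_test_pts``. ``preds`` maps head names ('reg', 'rot', 'vel', ...)
    to tensors of shape (B, C, H, W).
    """
    out = {}
    for key, value in preds.items():
        if horizontal:
            value = torch.flip(value, dims=[2])
            if key == 'reg':
                value[:, 1, ...] = 1 - value[:, 1, ...]
            elif key in ('rot', 'vel'):
                channel = 0 if key == 'rot' else 1
                value[:, channel, ...] = -value[:, channel, ...]
        if vertical:
            value = torch.flip(value, dims=[3])
            if key == 'reg':
                value[:, 0, ...] = 1 - value[:, 0, ...]
            elif key in ('rot', 'vel'):
                channel = 1 if key == 'rot' else 0
                value[:, channel, ...] = -value[:, channel, ...]
        out[key] = value
    return out

In the removed method, every flipped variant was mapped back like this, variants sharing the same pcd_scale_factor were averaged before box decoding, and boxes decoded at different scales were finally merged with merge_aug_bboxes_3d.
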
@@ -91,7 +91,7 @@ class PillarFeatureNet(nn.Module):
         self.point_cloud_range = point_cloud_range

     @force_fp32(out_fp16=True)
-    def forward(self, features, num_points, coors):
+    def forward(self, features, num_points, coors, *args, **kwargs):
         """Forward function.

         Args:
......
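
The final hunk is the "fix pillars encode" change itself: PillarFeatureNet.forward now accepts and silently ignores extra positional and keyword arguments, so a caller that passes anything beyond features, num_points and coors no longer fails with a "forward() takes 4 positional arguments but 5 were given" TypeError. A minimal stand-in sketch of the interface change; the class below is a toy substitute, not the real PillarFeatureNet from mmdet3d.models.voxel_encoders:

import torch
from torch import nn


class TinyPillarEncoder(nn.Module):
    """Toy stand-in that only demonstrates the new forward signature."""

    def __init__(self, in_channels=4, out_channels=64):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    # Matches the fixed signature: extra arguments from newer callers are
    # accepted and ignored instead of raising a TypeError.
    def forward(self, features, num_points, coors, *args, **kwargs):
        # features: (num_pillars, max_points, C); num_points and coors are
        # unused in this toy but kept for signature parity.
        return self.linear(features).max(dim=1).values


encoder = TinyPillarEncoder()
features = torch.rand(2, 20, 4)         # 2 pillars, 20 points, 4 features
num_points = torch.tensor([20, 13])     # valid points per pillar
coors = torch.tensor([[0, 0, 10, 12],
                      [0, 0, 11, 15]])  # (batch_idx, z, y, x)

# A caller that forwards an extra argument (e.g. metadata) now works.
pillar_features = encoder(features, num_points, coors, [{'flip': False}])
print(pillar_features.shape)            # torch.Size([2, 64])

Accepting *args / **kwargs presumably lets different detectors drive all voxel encoders through one uniform call signature.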