Unverified Commit 12c3b19d authored by SekiroRong's avatar SekiroRong Committed by GitHub
Browse files

[fix] fix bug in projects/PETR (#2212)

parent 8bf2f5a4
......@@ -8,14 +8,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
import numpy as np
import torch
from mmengine.structures import InstanceData
import mmdet3d
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import LiDARInstance3DBoxes, limit_period
from mmdet3d.structures.ops import bbox3d2result
from .grid_mask import GridMask
......@@ -174,8 +171,6 @@ class PETR(MVXTwoStageDetector):
gt_labels_3d = [gt.labels_3d for gt in batch_gt_instances_3d]
gt_bboxes_ignore = None
gt_bboxes_3d = self.LidarBox3dVersionTransfrom(gt_bboxes_3d)
batch_img_metas = self.add_lidar2img(img, batch_img_metas)
img_feats = self.extract_feat(img=img, img_metas=batch_img_metas)
......@@ -265,8 +260,8 @@ class PETR(MVXTwoStageDetector):
Returns:
batch_input_metas (list[dict]): Meta info with lidar2img added
"""
lidar2img_rts = []
for meta in batch_input_metas:
lidar2img_rts = []
# obtain lidar to image transformation matrix
for i in range(len(meta['cam2img'])):
lidar2cam_rt = torch.tensor(meta['lidar2cam'][i]).double()
......@@ -281,19 +276,7 @@ class PETR(MVXTwoStageDetector):
# and LoadMultiViewImageFromMultiSweepsFiles.
lidar2img_rts.append(lidar2img_rt)
meta['lidar2img'] = lidar2img_rts
meta['img_shape'] = [i.shape for i in img[0]]
return batch_input_metas
img_shape = meta['img_shape'][:3]
meta['img_shape'] = [img_shape] * len(img[0])
def LidarBox3dVersionTransfrom(self, gt_bboxes_3d):
if int(mmdet3d.__version__[0]) >= 1:
# Begin hack adaptation to mmdet3d v1.0 ####
gt_bboxes_3d = gt_bboxes_3d[0].tensor
gt_bboxes_3d[:, [3, 4]] = gt_bboxes_3d[:, [4, 3]]
gt_bboxes_3d[:, 6] = -gt_bboxes_3d[:, 6] - np.pi / 2
gt_bboxes_3d[:, 6] = limit_period(
gt_bboxes_3d[:, 6], period=np.pi * 2)
gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d, box_dim=9)
gt_bboxes_3d = [gt_bboxes_3d]
return gt_bboxes_3d
return batch_input_metas
......@@ -185,8 +185,8 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view = len(results['lidar2cam'])
for view in range(num_view):
results['lidar2cam'][view] = (
torch.tensor(np.array(results['lidar2cam'][view])).float()
@ rot_mat_inv).numpy()
torch.tensor(np.array(results['lidar2cam'][view]).T).float()
@ rot_mat_inv).T.numpy()
return
......@@ -203,5 +203,7 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view = len(results['lidar2cam'])
for view in range(num_view):
results['lidar2cam'][view] = (torch.tensor(
rot_mat_inv.T @ results['lidar2cam'][view]).float()).numpy()
rot_mat_inv.T
@ results['lidar2cam'][view].T).float()).T.numpy()
return
......@@ -2,7 +2,6 @@
import numpy as np
import torch
import mmdet3d
from mmdet3d.structures.bbox_3d.utils import limit_period
......@@ -11,19 +10,21 @@ def normalize_bbox(bboxes, pc_range):
cx = bboxes[..., 0:1]
cy = bboxes[..., 1:2]
cz = bboxes[..., 2:3]
w = bboxes[..., 3:4].log()
length = bboxes[..., 4:5].log()
h = bboxes[..., 5:6].log()
length = bboxes[..., 3:4].log()
width = bboxes[..., 4:5].log()
height = bboxes[..., 5:6].log()
rot = bboxes[..., 6:7]
rot = -bboxes[..., 6:7] - np.pi / 2
rot = limit_period(rot, period=np.pi * 2)
if bboxes.size(-1) > 7:
vx = bboxes[..., 7:8]
vy = bboxes[..., 8:9]
normalized_bboxes = torch.cat(
(cx, cy, w, length, cz, h, rot.sin(), rot.cos(), vx, vy), dim=-1)
(cx, cy, length, width, cz, height, rot.sin(), rot.cos(), vx, vy),
dim=-1)
else:
normalized_bboxes = torch.cat(
(cx, cy, w, length, cz, h, rot.sin(), rot.cos()), dim=-1)
(cx, cy, length, width, cz, height, rot.sin(), rot.cos()), dim=-1)
return normalized_bboxes
......@@ -33,6 +34,8 @@ def denormalize_bbox(normalized_bboxes, pc_range):
rot_cosine = normalized_bboxes[..., 7:8]
rot = torch.atan2(rot_sine, rot_cosine)
rot = -rot - np.pi / 2
rot = limit_period(rot, period=np.pi * 2)
# center in the bev
cx = normalized_bboxes[..., 0:1]
......@@ -40,30 +43,21 @@ def denormalize_bbox(normalized_bboxes, pc_range):
cz = normalized_bboxes[..., 4:5]
# size
w = normalized_bboxes[..., 2:3]
length = normalized_bboxes[..., 3:4]
h = normalized_bboxes[..., 5:6]
length = normalized_bboxes[..., 2:3]
width = normalized_bboxes[..., 3:4]
height = normalized_bboxes[..., 5:6]
w = w.exp()
width = width.exp()
length = length.exp()
h = h.exp()
height = height.exp()
if normalized_bboxes.size(-1) > 8:
# velocity
vx = normalized_bboxes[:, 8:9]
vy = normalized_bboxes[:, 9:10]
denormalized_bboxes = torch.cat(
[cx, cy, cz, w, length, h, rot, vx, vy], dim=-1)
[cx, cy, cz, length, width, height, rot, vx, vy], dim=-1)
else:
denormalized_bboxes = torch.cat([cx, cy, cz, w, length, h, rot],
dim=-1)
denormalized_bboxes = torch.cat(
[cx, cy, cz, length, width, height, rot], dim=-1)
if int(mmdet3d.__version__[0]) >= 1:
denormalized_bboxes_clone = denormalized_bboxes.clone()
denormalized_bboxes[:, 3] = denormalized_bboxes_clone[:, 4]
denormalized_bboxes[:, 4] = denormalized_bboxes_clone[:, 3]
# change yaw
denormalized_bboxes[:,
6] = -denormalized_bboxes_clone[:, 6] - np.pi / 2
denormalized_bboxes[:, 6] = limit_period(
denormalized_bboxes[:, 6], period=np.pi * 2)
return denormalized_bboxes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment