Unverified Commit 12c3b19d authored by SekiroRong's avatar SekiroRong Committed by GitHub
Browse files

[fix] fix bug in projects/PETR (#2212)

parent 8bf2f5a4
...@@ -8,14 +8,11 @@ ...@@ -8,14 +8,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
import numpy as np
import torch import torch
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
import mmdet3d
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import LiDARInstance3DBoxes, limit_period
from mmdet3d.structures.ops import bbox3d2result from mmdet3d.structures.ops import bbox3d2result
from .grid_mask import GridMask from .grid_mask import GridMask
...@@ -174,8 +171,6 @@ class PETR(MVXTwoStageDetector): ...@@ -174,8 +171,6 @@ class PETR(MVXTwoStageDetector):
gt_labels_3d = [gt.labels_3d for gt in batch_gt_instances_3d] gt_labels_3d = [gt.labels_3d for gt in batch_gt_instances_3d]
gt_bboxes_ignore = None gt_bboxes_ignore = None
gt_bboxes_3d = self.LidarBox3dVersionTransfrom(gt_bboxes_3d)
batch_img_metas = self.add_lidar2img(img, batch_img_metas) batch_img_metas = self.add_lidar2img(img, batch_img_metas)
img_feats = self.extract_feat(img=img, img_metas=batch_img_metas) img_feats = self.extract_feat(img=img, img_metas=batch_img_metas)
...@@ -265,8 +260,8 @@ class PETR(MVXTwoStageDetector): ...@@ -265,8 +260,8 @@ class PETR(MVXTwoStageDetector):
Returns: Returns:
batch_input_metas (list[dict]): Meta info with lidar2img added batch_input_metas (list[dict]): Meta info with lidar2img added
""" """
lidar2img_rts = []
for meta in batch_input_metas: for meta in batch_input_metas:
lidar2img_rts = []
# obtain lidar to image transformation matrix # obtain lidar to image transformation matrix
for i in range(len(meta['cam2img'])): for i in range(len(meta['cam2img'])):
lidar2cam_rt = torch.tensor(meta['lidar2cam'][i]).double() lidar2cam_rt = torch.tensor(meta['lidar2cam'][i]).double()
...@@ -281,19 +276,7 @@ class PETR(MVXTwoStageDetector): ...@@ -281,19 +276,7 @@ class PETR(MVXTwoStageDetector):
# and LoadMultiViewImageFromMultiSweepsFiles. # and LoadMultiViewImageFromMultiSweepsFiles.
lidar2img_rts.append(lidar2img_rt) lidar2img_rts.append(lidar2img_rt)
meta['lidar2img'] = lidar2img_rts meta['lidar2img'] = lidar2img_rts
meta['img_shape'] = [i.shape for i in img[0]] img_shape = meta['img_shape'][:3]
return batch_input_metas meta['img_shape'] = [img_shape] * len(img[0])
def LidarBox3dVersionTransfrom(self, gt_bboxes_3d): return batch_input_metas
if int(mmdet3d.__version__[0]) >= 1:
# Begin hack adaptation to mmdet3d v1.0 ####
gt_bboxes_3d = gt_bboxes_3d[0].tensor
gt_bboxes_3d[:, [3, 4]] = gt_bboxes_3d[:, [4, 3]]
gt_bboxes_3d[:, 6] = -gt_bboxes_3d[:, 6] - np.pi / 2
gt_bboxes_3d[:, 6] = limit_period(
gt_bboxes_3d[:, 6], period=np.pi * 2)
gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d, box_dim=9)
gt_bboxes_3d = [gt_bboxes_3d]
return gt_bboxes_3d
...@@ -185,8 +185,8 @@ class GlobalRotScaleTransImage(BaseTransform): ...@@ -185,8 +185,8 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view = len(results['lidar2cam']) num_view = len(results['lidar2cam'])
for view in range(num_view): for view in range(num_view):
results['lidar2cam'][view] = ( results['lidar2cam'][view] = (
torch.tensor(np.array(results['lidar2cam'][view])).float() torch.tensor(np.array(results['lidar2cam'][view]).T).float()
@ rot_mat_inv).numpy() @ rot_mat_inv).T.numpy()
return return
...@@ -203,5 +203,7 @@ class GlobalRotScaleTransImage(BaseTransform): ...@@ -203,5 +203,7 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view = len(results['lidar2cam']) num_view = len(results['lidar2cam'])
for view in range(num_view): for view in range(num_view):
results['lidar2cam'][view] = (torch.tensor( results['lidar2cam'][view] = (torch.tensor(
rot_mat_inv.T @ results['lidar2cam'][view]).float()).numpy() rot_mat_inv.T
@ results['lidar2cam'][view].T).float()).T.numpy()
return return
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import numpy as np import numpy as np
import torch import torch
import mmdet3d
from mmdet3d.structures.bbox_3d.utils import limit_period from mmdet3d.structures.bbox_3d.utils import limit_period
...@@ -11,19 +10,21 @@ def normalize_bbox(bboxes, pc_range): ...@@ -11,19 +10,21 @@ def normalize_bbox(bboxes, pc_range):
cx = bboxes[..., 0:1] cx = bboxes[..., 0:1]
cy = bboxes[..., 1:2] cy = bboxes[..., 1:2]
cz = bboxes[..., 2:3] cz = bboxes[..., 2:3]
w = bboxes[..., 3:4].log() length = bboxes[..., 3:4].log()
length = bboxes[..., 4:5].log() width = bboxes[..., 4:5].log()
h = bboxes[..., 5:6].log() height = bboxes[..., 5:6].log()
rot = bboxes[..., 6:7] rot = -bboxes[..., 6:7] - np.pi / 2
rot = limit_period(rot, period=np.pi * 2)
if bboxes.size(-1) > 7: if bboxes.size(-1) > 7:
vx = bboxes[..., 7:8] vx = bboxes[..., 7:8]
vy = bboxes[..., 8:9] vy = bboxes[..., 8:9]
normalized_bboxes = torch.cat( normalized_bboxes = torch.cat(
(cx, cy, w, length, cz, h, rot.sin(), rot.cos(), vx, vy), dim=-1) (cx, cy, length, width, cz, height, rot.sin(), rot.cos(), vx, vy),
dim=-1)
else: else:
normalized_bboxes = torch.cat( normalized_bboxes = torch.cat(
(cx, cy, w, length, cz, h, rot.sin(), rot.cos()), dim=-1) (cx, cy, length, width, cz, height, rot.sin(), rot.cos()), dim=-1)
return normalized_bboxes return normalized_bboxes
...@@ -33,6 +34,8 @@ def denormalize_bbox(normalized_bboxes, pc_range): ...@@ -33,6 +34,8 @@ def denormalize_bbox(normalized_bboxes, pc_range):
rot_cosine = normalized_bboxes[..., 7:8] rot_cosine = normalized_bboxes[..., 7:8]
rot = torch.atan2(rot_sine, rot_cosine) rot = torch.atan2(rot_sine, rot_cosine)
rot = -rot - np.pi / 2
rot = limit_period(rot, period=np.pi * 2)
# center in the bev # center in the bev
cx = normalized_bboxes[..., 0:1] cx = normalized_bboxes[..., 0:1]
...@@ -40,30 +43,21 @@ def denormalize_bbox(normalized_bboxes, pc_range): ...@@ -40,30 +43,21 @@ def denormalize_bbox(normalized_bboxes, pc_range):
cz = normalized_bboxes[..., 4:5] cz = normalized_bboxes[..., 4:5]
# size # size
w = normalized_bboxes[..., 2:3] length = normalized_bboxes[..., 2:3]
length = normalized_bboxes[..., 3:4] width = normalized_bboxes[..., 3:4]
h = normalized_bboxes[..., 5:6] height = normalized_bboxes[..., 5:6]
w = w.exp() width = width.exp()
length = length.exp() length = length.exp()
h = h.exp() height = height.exp()
if normalized_bboxes.size(-1) > 8: if normalized_bboxes.size(-1) > 8:
# velocity # velocity
vx = normalized_bboxes[:, 8:9] vx = normalized_bboxes[:, 8:9]
vy = normalized_bboxes[:, 9:10] vy = normalized_bboxes[:, 9:10]
denormalized_bboxes = torch.cat( denormalized_bboxes = torch.cat(
[cx, cy, cz, w, length, h, rot, vx, vy], dim=-1) [cx, cy, cz, length, width, height, rot, vx, vy], dim=-1)
else: else:
denormalized_bboxes = torch.cat([cx, cy, cz, w, length, h, rot], denormalized_bboxes = torch.cat(
dim=-1) [cx, cy, cz, length, width, height, rot], dim=-1)
if int(mmdet3d.__version__[0]) >= 1:
denormalized_bboxes_clone = denormalized_bboxes.clone()
denormalized_bboxes[:, 3] = denormalized_bboxes_clone[:, 4]
denormalized_bboxes[:, 4] = denormalized_bboxes_clone[:, 3]
# change yaw
denormalized_bboxes[:,
6] = -denormalized_bboxes_clone[:, 6] - np.pi / 2
denormalized_bboxes[:, 6] = limit_period(
denormalized_bboxes[:, 6], period=np.pi * 2)
return denormalized_bboxes return denormalized_bboxes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment