Unverified Commit c6c3c46d authored by ChaimZhu, committed by GitHub

[Feature] Add MonoFlex data augmentation label generation function (#1026)

* add pipeline

* fix typos

* fix unittest

* fix_typos

* fix_typos

* change pipeline to func

* refactor

* refine gen indices

* add device to gen indices

* change tunc_objs_handle to handle_objs

* fix comments

* add numpy and tensor time comparision

* fix

* add numba accelerate

* fix format

* remove unnecssary func

* update edge_indices docstrings

* fix some comments
parent 0f8181f1
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_sigmoid import clip_sigmoid
from .edge_indices import get_edge_indices
from .gen_keypoints import get_keypoints
from .handle_objs import filter_outside_objs, handle_proj_objs
from .mlp import MLP
__all__ = [
'clip_sigmoid', 'MLP', 'get_edge_indices', 'filter_outside_objs',
'handle_proj_objs', 'get_keypoints'
]
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
def get_edge_indices(img_metas,
step=1,
pad_mode='default',
dtype=np.float32,
device='cpu'):
"""Function to filter the objects label outside the image.
The edge_indices are generated using numpy on cpu rather
than on CUDA due to the latency issue. When batch size = 8,
this function with numpy array is ~8 times faster than that
with CUDA tensor (0.09s and 0.72s in 100 runs).
Args:
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
step (int, optional): Step size used for generating
edge indices. Default: 1.
pad_mode (str, optional): Padding mode during data pipeline.
Default: 'default'.
dtype (np.dtype, optional): Dtype of the generated edge indices.
Default: np.float32.
device (str, optional): Device of edge indices tensor.
Default: 'cpu'.
Returns:
list[Tensor]: Edge indices for each image in batch data.
"""
edge_indices_list = []
for i in range(len(img_metas)):
img_shape = img_metas[i]['img_shape']
h, w = img_shape[:2]
edge_indices = []
if pad_mode == 'default':
x_min = 0
y_min = 0
x_max, y_max = w - 1, h - 1
else:
raise NotImplementedError
# left
y = np.arange(y_min, y_max, step, dtype=dtype)
x = np.ones(len(y)) * x_min
edge_indices_edge = np.stack((x, y), axis=1)
edge_indices.append(edge_indices_edge)
# bottom
x = np.arange(x_min, x_max, step, dtype=dtype)
y = np.ones(len(x)) * y_max
edge_indices_edge = np.stack((x, y), axis=1)
edge_indices.append(edge_indices_edge)
# right
y = np.arange(y_max, y_min, -step, dtype=dtype)
x = np.ones(len(y)) * x_max
edge_indices_edge = np.stack((x, y), axis=1)
edge_indices.append(edge_indices_edge)
# top
x = np.arange(x_max, x_min, -step, dtype=dtype)
y = np.ones(len(x)) * y_min
edge_indices_edge = np.stack((x, y), axis=1)
edge_indices.append(edge_indices_edge)
edge_indices = \
np.concatenate([index for index in edge_indices], axis=0)
edge_indices = torch.from_numpy(edge_indices).to(device).long()
edge_indices_list.append(edge_indices)
return edge_indices_list
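A minimal usage sketch for get_edge_indices (the 300 x 400 image shape is only an illustrative assumption; the expected point count mirrors the unit test further down):

from mmdet3d.models.utils import get_edge_indices

img_metas = [dict(img_shape=(300, 400, 3))]
edge_indices_list = get_edge_indices(img_metas, step=1, device='cpu')
# One (num_points, 2) LongTensor per image; 299 + 399 + 299 + 399 = 1396
# border points for a 300 x 400 image, since each np.arange excludes its
# end point when walking the left, bottom, right and top edges.
assert edge_indices_list[0].shape == (1396, 2)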
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet3d.core.bbox import points_cam2img
def get_keypoints(gt_bboxes_3d_list,
centers2d_list,
img_metas,
use_local_coords=True):
"""Function to filter the objects label outside the image.
Args:
gt_bboxes_3d_list (list[:obj:`CameraInstance3DBoxes`]): 3D ground
truth bboxes of each image, each containing num_gt boxes.
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
shape (num_gt, 2).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
use_local_coords (bool, optional): Whether to use local coordinates
for keypoints. Default: True.
Returns:
tuple[list[Tensor]]: It contains two elements, the first is the
keypoints for each projected 2D bbox in batch data. The second is
the visible mask of depth calculated by keypoints.
"""
assert len(gt_bboxes_3d_list) == len(centers2d_list)
bs = len(gt_bboxes_3d_list)
keypoints2d_list = []
keypoints_depth_mask_list = []
for i in range(bs):
gt_bboxes_3d = gt_bboxes_3d_list[i]
centers2d = centers2d_list[i]
img_shape = img_metas[i]['img_shape']
cam2img = img_metas[i]['cam2img']
h, w = img_shape[:2]
# (N, 8, 3)
corners3d = gt_bboxes_3d.corners
top_centers3d = torch.mean(corners3d[:, [0, 1, 4, 5], :], dim=1)
bot_centers3d = torch.mean(corners3d[:, [2, 3, 6, 7], :], dim=1)
# (N, 2, 3)
top_bot_centers3d = torch.stack((top_centers3d, bot_centers3d), dim=1)
keypoints3d = torch.cat((corners3d, top_bot_centers3d), dim=1)
# (N, 10, 2)
keypoints2d = points_cam2img(keypoints3d, cam2img)
# keypoints mask: keypoints must be inside
# the image and in front of the camera
keypoints_x_visible = (keypoints2d[..., 0] >= 0) & (
keypoints2d[..., 0] <= w - 1)
keypoints_y_visible = (keypoints2d[..., 1] >= 0) & (
keypoints2d[..., 1] <= h - 1)
keypoints_z_visible = (keypoints3d[..., -1] > 0)
# (N, 10)
keypoints_visible = keypoints_x_visible & \
keypoints_y_visible & keypoints_z_visible
# center, diag-02, diag-13
keypoints_depth_valid = torch.stack(
(keypoints_visible[:, [8, 9]].all(dim=1),
keypoints_visible[:, [0, 3, 5, 6]].all(dim=1),
keypoints_visible[:, [1, 2, 4, 7]].all(dim=1)),
dim=1)
keypoints_visible = keypoints_visible.float()
if use_local_coords:
keypoints2d = torch.cat((keypoints2d - centers2d.unsqueeze(1),
keypoints_visible.unsqueeze(-1)),
dim=2)
else:
keypoints2d = torch.cat(
(keypoints2d, keypoints_visible.unsqueeze(-1)), dim=2)
keypoints2d_list.append(keypoints2d)
keypoints_depth_mask_list.append(keypoints_depth_valid)
return (keypoints2d_list, keypoints_depth_mask_list)
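A minimal usage sketch for get_keypoints, assuming a CameraInstance3DBoxes ground truth and a 4x4 cam2img matrix similar to the one in the unit test below (the matrix values here are only illustrative):

import torch
from mmdet3d.core.bbox import CameraInstance3DBoxes
from mmdet3d.models.utils import get_keypoints

gt_bboxes_3d_list = [CameraInstance3DBoxes(torch.rand([3, 7]))]
centers2d_list = [torch.rand([3, 2]) * 100]
img_metas = [
    dict(
        cam2img=[[1260.8, 0.0, 808.0, 40.1], [0.0, 1260.8, 495.3, 2.3],
                 [0.0, 0.0, 1.0, 0.003], [0.0, 0.0, 0.0, 1.0]],
        img_shape=(300, 400))
]
keypoints2d_list, keypoints_depth_mask_list = get_keypoints(
    gt_bboxes_3d_list, centers2d_list, img_metas)
# 8 corners + top/bottom centers -> 10 keypoints per box, each stored as
# (x, y, visibility); the depth mask covers the center pair and the two
# corner diagonals.
assert keypoints2d_list[0].shape == (3, 10, 3)
assert keypoints_depth_mask_list[0].shape == (3, 3)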
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
gt_labels_3d_list, centers2d_list, img_metas):
"""Function to filter the objects label outside the image.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
each has shape (num_gt, 4).
gt_labels_list (list[Tensor]): Ground truth labels of each box,
each has shape (num_gt,).
gt_bboxes_3d_list (list[Tensor]): 3D Ground truth bboxes of each
image, each has shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of each
box, each has shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
each has shape (num_gt, 2).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
"""
bs = len(centers2d_list)
for i in range(bs):
centers2d = centers2d_list[i].clone()
img_shape = img_metas[i]['img_shape']
keep_inds = (centers2d[:, 0] > 0) & \
(centers2d[:, 0] < img_shape[1]) & \
(centers2d[:, 1] > 0) & \
(centers2d[:, 1] < img_shape[0])
centers2d_list[i] = centers2d[keep_inds]
gt_labels_list[i] = gt_labels_list[i][keep_inds]
gt_bboxes_list[i] = gt_bboxes_list[i][keep_inds]
gt_bboxes_3d_list[i].tensor = gt_bboxes_3d_list[i].tensor[keep_inds]
gt_labels_3d_list[i] = gt_labels_3d_list[i][keep_inds]
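# A hedged in-place usage note: after calling
#   filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
#                       gt_labels_3d_list, centers2d_list, img_metas)
# each per-image entry of the five ground truth lists keeps only the
# objects whose centers2d fall strictly inside (0, w) x (0, h); the
# function returns nothing.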
def get_centers2d_target(centers2d, centers, img_shape):
"""Function to get target centers2d.
Args:
centers2d (Tensor): Projected 3D centers onto 2D images.
centers (Tensor): Centers of 2d gt bboxes.
img_shape (tuple): Resized image shape.
Returns:
torch.Tensor: Projected 3D centers (centers2D) target.
"""
N = centers2d.shape[0]
h, w = img_shape[:2]
valid_intersects = centers2d.new_zeros((N, 2))
a = (centers[:, 1] - centers2d[:, 1]) / (centers[:, 0] - centers2d[:, 0])
b = centers[:, 1] - a * centers[:, 0]
left_y = b
right_y = (w - 1) * a + b
top_x = -b / a
bottom_x = (h - 1 - b) / a
left_coors = torch.stack((left_y.new_zeros(N, ), left_y), dim=1)
right_coors = torch.stack((right_y.new_full((N, ), w - 1), right_y), dim=1)
top_coors = torch.stack((top_x, top_x.new_zeros(N, )), dim=1)
bottom_coors = torch.stack((bottom_x, bottom_x.new_full((N, ), h - 1)),
dim=1)
intersects = torch.stack(
[left_coors, right_coors, top_coors, bottom_coors], dim=1)
intersects_x = intersects[:, :, 0]
intersects_y = intersects[:, :, 1]
inds = (intersects_x >= 0) & (intersects_x <=
w - 1) & (intersects_y >= 0) & (
intersects_y <= h - 1)
valid_intersects = intersects[inds].reshape(N, 2, 2)
dist = torch.norm(valid_intersects - centers2d.unsqueeze(1), dim=2)
min_idx = torch.argmin(dist, dim=1)
min_idx = min_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2)
centers2d_target = valid_intersects.gather(dim=1, index=min_idx).squeeze(1)
return centers2d_target
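# Hedged worked example (numbers borrowed from the unit test below): for an
# outside center centers2d = (-99.86, 199.45) matched with the 2D box center
# (50.025, 149.7) in a 300 x 400 image, the line through them has slope
# a ~= -0.332 and intercept b ~= 166.30, so of the four border intersections
# only the left and right ones are valid, and the closer left-border point
# (0, 166.30) becomes the centers2d target.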
def handle_proj_objs(centers2d_list, gt_bboxes_list, img_metas):
"""Function to handle projected object centers2d, generate target
centers2d.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
shape (num_gt, 4).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D image,
shape (num_gt, 2).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
tuple[list[Tensor]]: It contains three elements. The first is the
target centers2d after handling the truncated objects. The second
is the offsets between the original centers2d and the rounded int
target centers2d, and the last is the truncation mask for each
object in the batch data.
"""
bs = len(centers2d_list)
centers2d_target_list = []
trunc_mask_list = []
offsets2d_list = []
# for now, only pad mode that img is padded by right and
# bottom side is supported.
for i in range(bs):
centers2d = centers2d_list[i]
gt_bbox = gt_bboxes_list[i]
img_shape = img_metas[i]['img_shape']
centers2d_target = centers2d.clone()
inside_inds = (centers2d[:, 0] > 0) & \
(centers2d[:, 0] < img_shape[1]) & \
(centers2d[:, 1] > 0) & \
(centers2d[:, 1] < img_shape[0])
outside_inds = ~inside_inds
# if there are outside objects
if outside_inds.any():
centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
outside_centers2d = centers2d[outside_inds]
match_centers = centers[outside_inds]
target_outside_centers2d = get_centers2d_target(
outside_centers2d, match_centers, img_shape)
centers2d_target[outside_inds] = target_outside_centers2d
offsets2d = centers2d - centers2d_target.round().int()
trunc_mask = outside_inds
centers2d_target_list.append(centers2d_target)
trunc_mask_list.append(trunc_mask)
offsets2d_list.append(offsets2d)
return (centers2d_target_list, offsets2d_list, trunc_mask_list)
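A minimal usage sketch for handle_proj_objs (the numbers are taken from the truncation unit test below): a truncated center is snapped onto the image border, and the stored offset recovers the original sub-pixel center.

import torch
from mmdet3d.models.utils import handle_proj_objs

centers2d_list = [torch.tensor([[-99.86, 199.45], [201.20, 99.86]])]
gt_bboxes_list = [
    torch.tensor([[0.25, 99.8, 99.8, 199.6], [100.2, 20.1, 300.8, 180.7]])
]
img_metas = [dict(img_shape=[300, 400])]
centers2d_target_list, offsets2d_list, trunc_mask_list = handle_proj_objs(
    centers2d_list, gt_bboxes_list, img_metas)
# the first object lies outside the image and is snapped to the left border,
# the second one is kept unchanged:
# trunc_mask_list[0] -> tensor([True, False])
# centers2d_target_list[0][0] -> approximately tensor([0.0, 166.3044])
# centers2d_target_list[0].round().int() + offsets2d_list[0] reproduces
# the original centers2d_list[0].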
@@ -4,6 +4,9 @@ import pytest
import torch
from mmdet3d.core import array_converter, draw_heatmap_gaussian, points_img2cam
from mmdet3d.core.bbox import CameraInstance3DBoxes
from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
get_keypoints, handle_proj_objs)
def test_gaussian():
@@ -188,3 +191,94 @@ def test_points_img2cam():
expected_xyzs = torch.tensor([[-0.4864, -0.2155, 0.7576],
[-0.6299, -0.2796, 0.9813]])
assert torch.allclose(xyzs, expected_xyzs, atol=1e-3)
def test_generate_edge_indices():
img_metas = [dict(img_shape=[300, 400]), dict(img_shape=[500, 450])]
edge_indices_list = get_edge_indices(img_metas)
assert edge_indices_list[0].shape[0] == 1396
assert edge_indices_list[1].shape[0] == 1896
def test_truncation_hanlde():
centers2d_list = [
torch.tensor([[-99.86, 199.45], [499.50, 399.20], [201.20, 99.86]])
]
gt_bboxes_list = [
torch.tensor([[0.25, 99.8, 99.8, 199.6], [300.2, 250.1, 399.8, 299.6],
[100.2, 20.1, 300.8, 180.7]])
]
img_metas = [dict(img_shape=[300, 400])]
centers2d_target_list, offsets2d_list, trunc_mask_list = \
handle_proj_objs(centers2d_list, gt_bboxes_list, img_metas)
centers2d_target = torch.tensor([[0., 166.30435501], [379.03437877, 299.],
[201.2, 99.86]])
offsets2d = torch.tensor([[-99.86, 33.45], [120.5, 100.2], [0.2, -0.14]])
trunc_mask = torch.tensor([True, True, False])
assert torch.allclose(centers2d_target_list[0], centers2d_target)
assert torch.allclose(offsets2d_list[0], offsets2d, atol=1e-4)
assert torch.all(trunc_mask_list[0] == trunc_mask)
assert torch.allclose(
centers2d_target_list[0].round().int() + offsets2d_list[0],
centers2d_list[0])
def test_filter_outside_objs():
centers2d_list = [
torch.tensor([[-99.86, 199.45], [499.50, 399.20], [201.20, 99.86]]),
torch.tensor([[-47.86, 199.45], [410.50, 399.20], [401.20, 349.86]])
]
gt_bboxes_list = [
torch.rand([3, 4], dtype=torch.float32),
torch.rand([3, 4], dtype=torch.float32)
]
gt_bboxes_3d_list = [
CameraInstance3DBoxes(torch.rand([3, 7]), box_dim=7),
CameraInstance3DBoxes(torch.rand([3, 7]), box_dim=7)
]
gt_labels_list = [torch.tensor([0, 1, 2]), torch.tensor([2, 0, 0])]
gt_labels_3d_list = [torch.tensor([0, 1, 2]), torch.tensor([2, 0, 0])]
img_metas = [dict(img_shape=[300, 400]), dict(img_shape=[500, 450])]
filter_outside_objs(gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
gt_labels_3d_list, centers2d_list, img_metas)
assert len(centers2d_list[0]) == len(gt_bboxes_3d_list[0]) == \
len(gt_bboxes_list[0]) == len(gt_labels_3d_list[0]) == \
len(gt_labels_list[0]) == 1
assert len(centers2d_list[1]) == len(gt_bboxes_3d_list[1]) == \
len(gt_bboxes_list[1]) == len(gt_labels_3d_list[1]) == \
len(gt_labels_list[1]) == 2
def test_generate_keypoints():
centers2d_list = [
torch.tensor([[-99.86, 199.45], [499.50, 399.20], [201.20, 99.86]]),
torch.tensor([[-47.86, 199.45], [410.50, 399.20], [401.20, 349.86]])
]
gt_bboxes_3d_list = [
CameraInstance3DBoxes(torch.rand([3, 7])),
CameraInstance3DBoxes(torch.rand([3, 7]))
]
img_metas = [
dict(
cam2img=[[1260.8474446004698, 0.0, 807.968244525554, 40.1111],
[0.0, 1260.8474446004698, 495.3344268742088, 2.34422],
[0.0, 0.0, 1.0, 0.00333333], [0.0, 0.0, 0.0, 1.0]],
img_shape=(300, 400)) for i in range(2)
]
keypoints2d_list, keypoints_depth_mask_list = \
get_keypoints(gt_bboxes_3d_list, centers2d_list, img_metas)
assert keypoints2d_list[0].shape == (3, 10, 3)
assert keypoints_depth_mask_list[0].shape == (3, 3)