Commit a9dc86e9 (parent 18eda5c1): init_0905
import numpy as np
import mmcv
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
"""Load multi channel images from a list of separate channel files.
Expects results['img_filename'] to be a list of filenames.
Args:
to_float32 (bool, optional): Whether to convert the img to float32.
Defaults to False.
color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
"""
def __init__(self, to_float32=False, color_type="unchanged"):
self.to_float32 = to_float32
self.color_type = color_type
def __call__(self, results):
"""Call function to load multi-view image from files.
Args:
results (dict): Result dict containing multi-view image filenames.
Returns:
dict: The result dict containing the multi-view image data.
Added keys and values are described below.
            - filename (list[str]): Multi-view image filenames.
            - img (list[np.ndarray]): List of multi-view image arrays.
- img_shape (tuple[int]): Shape of multi-view image arrays.
- ori_shape (tuple[int]): Shape of original image arrays.
- pad_shape (tuple[int]): Shape of padded image arrays.
- scale_factor (float): Scale factor.
- img_norm_cfg (dict): Normalization configuration of images.
"""
filename = results["img_filename"]
# img is of shape (h, w, c, num_views)
img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1
)
if self.to_float32:
img = img.astype(np.float32)
results["filename"] = filename
# unravel to list, see `DefaultFormatBundle` in formatting.py
# which will transpose each image separately and then stack into array
results["img"] = [img[..., i] for i in range(img.shape[-1])]
results["img_shape"] = img.shape
results["ori_shape"] = img.shape
# Set initial values for default meta_keys
results["pad_shape"] = img.shape
results["scale_factor"] = 1.0
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False,
)
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(to_float32={self.to_float32}, "
repr_str += f"color_type='{self.color_type}')"
return repr_str
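# Minimal sketch of the data layout produced by LoadMultiViewImageFromFiles: images are
# stacked along the last axis and then unraveled back into a per-view list. Synthetic
# arrays stand in for files read with mmcv.imread, so only the shape bookkeeping is shown.
def _demo_multiview_image_layout():
    import numpy as np
    num_views, h, w = 6, 8, 16
    views = [np.zeros((h, w, 3), dtype=np.uint8) for _ in range(num_views)]
    stacked = np.stack(views, axis=-1)                 # (h, w, 3, num_views)
    per_view = [stacked[..., i] for i in range(stacked.shape[-1])]
    assert len(per_view) == num_views and per_view[0].shape == (h, w, 3)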
@PIPELINES.register_module()
class LoadPointsFromFile(object):
"""Load Points From File.
Load points from file.
Args:
coord_type (str): The type of coordinates of points cloud.
            Available options include:
            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor datasets.
            - 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points.
Defaults to 6.
use_dim (list[int], optional): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height.
Defaults to False.
use_color (bool, optional): Whether to use color features.
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""
def __init__(
self,
coord_type,
load_dim=6,
use_dim=[0, 1, 2],
shift_height=False,
use_color=False,
file_client_args=dict(backend="disk"),
):
self.shift_height = shift_height
self.use_color = use_color
if isinstance(use_dim, int):
use_dim = list(range(use_dim))
assert (
max(use_dim) < load_dim
), f"Expect all used dimensions < {load_dim}, got {use_dim}"
assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]
self.coord_type = coord_type
self.load_dim = load_dim
self.use_dim = use_dim
self.file_client_args = file_client_args.copy()
self.file_client = None
def _load_points(self, pts_filename):
"""Private function to load point clouds data.
Args:
pts_filename (str): Filename of point clouds data.
Returns:
np.ndarray: An array containing point clouds data.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
try:
pts_bytes = self.file_client.get(pts_filename)
points = np.frombuffer(pts_bytes, dtype=np.float32)
except ConnectionError:
mmcv.check_file_exist(pts_filename)
if pts_filename.endswith(".npy"):
points = np.load(pts_filename)
else:
points = np.fromfile(pts_filename, dtype=np.float32)
return points
def __call__(self, results):
"""Call function to load points data from file.
Args:
results (dict): Result dict containing point clouds data.
Returns:
dict: The result dict containing the point clouds data.
Added key and value are described below.
            - points (np.ndarray): Point cloud data.
"""
pts_filename = results["pts_filename"]
points = self._load_points(pts_filename)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
attribute_dims = None
if self.shift_height:
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate(
[points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
)
attribute_dims = dict(height=3)
if self.use_color:
assert len(self.use_dim) >= 6
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(
color=[
points.shape[1] - 3,
points.shape[1] - 2,
points.shape[1] - 1,
]
)
)
results["points"] = points
return results
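# Minimal sketch of how the flat float32 buffer read by _load_points is reshaped with
# `load_dim` and sliced with `use_dim`; the buffer here is synthetic (real nuScenes
# sweeps store 5 floats per point: x, y, z, intensity, ring index).
def _demo_points_reshape():
    import numpy as np
    load_dim, use_dim = 5, [0, 1, 2, 3]
    raw = np.arange(20, dtype=np.float32)              # stand-in for np.frombuffer output
    points = raw.reshape(-1, load_dim)                 # (4, 5)
    points = points[:, use_dim]                        # keep x, y, z, intensity
    assert points.shape == (4, 4)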
import numpy as np
import mmcv
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor
@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
def __init__(self, downsample=1, max_depth=60):
if not isinstance(downsample, (list, tuple)):
downsample = [downsample]
self.downsample = downsample
self.max_depth = max_depth
def __call__(self, input_dict):
points = input_dict["points"][..., :3, None]
gt_depth = []
for i, lidar2img in enumerate(input_dict["lidar2img"]):
H, W = input_dict["img_shape"][i][:2]
pts_2d = (
np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
+ lidar2img[:3, 3]
)
pts_2d[:, :2] /= pts_2d[:, 2:3]
U = np.round(pts_2d[:, 0]).astype(np.int32)
V = np.round(pts_2d[:, 1]).astype(np.int32)
depths = pts_2d[:, 2]
mask = np.logical_and.reduce(
[
V >= 0,
V < H,
U >= 0,
U < W,
depths >= 0.1,
# depths <= self.max_depth,
]
)
V, U, depths = V[mask], U[mask], depths[mask]
sort_idx = np.argsort(depths)[::-1]
V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx]
depths = np.clip(depths, 0.1, self.max_depth)
for j, downsample in enumerate(self.downsample):
if len(gt_depth) < j + 1:
gt_depth.append([])
h, w = (int(H / downsample), int(W / downsample))
u = np.floor(U / downsample).astype(np.int32)
v = np.floor(V / downsample).astype(np.int32)
depth_map = np.ones([h, w], dtype=np.float32) * -1
depth_map[v, u] = depths
gt_depth[j].append(depth_map)
input_dict["gt_depth"] = [np.stack(x) for x in gt_depth]
return input_dict
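# Minimal sketch of the projection above, reduced to one view: an identity lidar2img
# matrix and two hand-made points (both hypothetical). Points are mapped to pixels and
# written into a downsampled map in far-to-near order, so the nearer point wins when two
# points fall into the same cell.
def _demo_depth_map():
    import numpy as np
    H, W, downsample = 8, 16, 2
    lidar2img = np.eye(4, dtype=np.float32)                    # hypothetical projection
    points = np.array([[2.0, 3.0, 5.0], [4.0, 6.0, 10.0]], dtype=np.float32)
    pts = points @ lidar2img[:3, :3].T + lidar2img[:3, 3]
    uv = pts[:, :2] / pts[:, 2:3]
    U, V, depths = np.round(uv[:, 0]).astype(int), np.round(uv[:, 1]).astype(int), pts[:, 2]
    order = np.argsort(depths)[::-1]                           # far first, near last
    U, V, depths = U[order], V[order], depths[order]
    depth_map = -np.ones((H // downsample, W // downsample), dtype=np.float32)
    depth_map[V // downsample, U // downsample] = depths
    assert depth_map[0, 0] == 5.0                              # nearer of the two points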
@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    def __init__(self):
pass
def __call__(self, input_dict):
input_dict["projection_mat"] = np.float32(
np.stack(input_dict["lidar2img"])
)
input_dict["image_wh"] = np.ascontiguousarray(
np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]
)
input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
input_dict["T_global"] = input_dict["lidar2global"]
if "cam_intrinsic" in input_dict:
input_dict["cam_intrinsic"] = np.float32(
np.stack(input_dict["cam_intrinsic"])
)
input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
# input_dict["focal"] = np.sqrt(
# np.abs(np.linalg.det(input_dict["cam_intrinsic"][:, :2, :2]))
# )
if "instance_inds" in input_dict:
input_dict["instance_id"] = input_dict["instance_inds"]
if "gt_bboxes_3d" in input_dict:
input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
)
input_dict["gt_bboxes_3d"] = DC(
to_tensor(input_dict["gt_bboxes_3d"]).float()
)
if "gt_labels_3d" in input_dict:
input_dict["gt_labels_3d"] = DC(
to_tensor(input_dict["gt_labels_3d"]).long()
)
imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
input_dict["img"] = DC(to_tensor(imgs), stack=True)
return input_dict
def limit_period(
self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
) -> np.ndarray:
limited_val = val - np.floor(val / period + offset) * period
return limited_val
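# Minimal sketch of limit_period with offset=0.5 and period=2*pi, the call used above to
# wrap gt_bboxes_3d yaw into [-pi, pi); the two input angles are arbitrary.
def _demo_limit_period():
    import numpy as np
    val = np.array([3.5 * np.pi, -3.0 * np.pi])
    wrapped = val - np.floor(val / (2 * np.pi) + 0.5) * (2 * np.pi)
    assert np.allclose(wrapped, [-0.5 * np.pi, -np.pi])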
@PIPELINES.register_module()
class InstanceNameFilter(object):
"""Filter GT objects by their names.
Args:
classes (list[str]): List of class names to be kept for training.
"""
def __init__(self, classes):
self.classes = classes
self.labels = list(range(len(self.classes)))
def __call__(self, input_dict):
"""Call function to filter objects by their names.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \
keys are updated in the result dict.
"""
gt_labels_3d = input_dict["gt_labels_3d"]
gt_bboxes_mask = np.array(
[n in self.labels for n in gt_labels_3d], dtype=np.bool_
)
input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][gt_bboxes_mask]
input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][gt_bboxes_mask]
if "instance_inds" in input_dict:
input_dict["instance_inds"] = input_dict["instance_inds"][
gt_bboxes_mask
]
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(classes={self.classes})"
return repr_str
@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
def __init__(
self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5]
):
self.class_dist_thred = class_dist_thred
def __call__(self, input_dict):
gt_bboxes_3d = input_dict["gt_bboxes_3d"]
gt_labels_3d = input_dict["gt_labels_3d"]
dist = np.sqrt(
np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1)
)
mask = np.array([False] * len(dist))
for label_idx, dist_thred in enumerate(self.class_dist_thred):
mask = np.logical_or(
mask,
np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred),
)
gt_bboxes_3d = gt_bboxes_3d[mask]
gt_labels_3d = gt_labels_3d[mask]
input_dict["gt_bboxes_3d"] = gt_bboxes_3d
input_dict["gt_labels_3d"] = gt_labels_3d
if "instance_inds" in input_dict:
input_dict["instance_inds"] = input_dict["instance_inds"][mask]
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f"(class_dist_thred={self.class_dist_thred})"
return repr_str
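# Minimal sketch of the per-class radial gate in CircleObjectRangeFilter: each class keeps
# only boxes whose BEV distance from the origin is within that class's radius. The two
# class thresholds below are hypothetical.
def _demo_circle_filter_mask():
    import numpy as np
    centers = np.array([[10.0, 0.0], [40.0, 0.0], [10.0, 0.0]])
    labels = np.array([0, 0, 1])
    class_dist_thred = [30.0, 5.0]
    dist = np.sqrt(np.sum(centers ** 2, axis=-1))
    mask = np.zeros(len(dist), dtype=bool)
    for label_idx, thred in enumerate(class_dist_thred):
        mask |= (labels == label_idx) & (dist <= thred)
    assert mask.tolist() == [True, False, False]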
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
"""Normalize the image.
Added key is "img_norm_cfg".
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB,
default is true.
"""
def __init__(self, mean, std, to_rgb=True):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def __call__(self, results):
"""Call function to normalize images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Normalized results, 'img_norm_cfg' key is added into
result dict.
"""
results["img"] = [
mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
for img in results["img"]
]
results["img_norm_cfg"] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb
)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
return repr_str
from .group_sampler import DistributedGroupSampler
from .distributed_sampler import DistributedSampler
from .sampler import SAMPLER, build_sampler
from .group_in_batch_sampler import (
GroupInBatchSampler,
)
import math
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler
from .sampler import SAMPLER
import pdb
import sys
class ForkedPdb(pdb.Pdb):
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
def set_trace():
ForkedPdb().set_trace(sys._getframe().f_back)
@SAMPLER.register_module()
class DistributedSampler(_DistributedSampler):
def __init__(
self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0
):
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle
)
# for the compatibility from PyTorch 1.3+
self.seed = seed if seed is not None else 0
def __iter__(self):
# deterministically shuffle based on epoch
assert not self.shuffle
if "data_infos" in dir(self.dataset):
timestamps = [
x["timestamp"] / 1e6 for x in self.dataset.data_infos
]
vehicle_idx = [
x["lidar_path"].split("/")[-1][:4]
if "lidar_path" in x
else None
for x in self.dataset.data_infos
]
else:
timestamps = [
x["timestamp"] / 1e6
for x in self.dataset.datasets[0].data_infos
] * len(self.dataset.datasets)
vehicle_idx = [
x["lidar_path"].split("/")[-1][:4]
if "lidar_path" in x
else None
for x in self.dataset.datasets[0].data_infos
] * len(self.dataset.datasets)
sequence_splits = []
for i in range(len(timestamps)):
if i == 0 or (
abs(timestamps[i] - timestamps[i - 1]) > 4
or vehicle_idx[i] != vehicle_idx[i - 1]
):
sequence_splits.append([i])
else:
sequence_splits[-1].append(i)
indices = []
        prefix_sum = 0
split_length = len(self.dataset) // self.num_replicas
for i in range(len(sequence_splits)):
            if prefix_sum >= (self.rank + 1) * split_length:
break
            elif prefix_sum >= self.rank * split_length:
indices.extend(sequence_splits[i])
            prefix_sum += len(sequence_splits[i])
self.num_samples = len(indices)
return iter(indices)
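# Minimal sketch of the sequence splitting used above: a time gap larger than 4 s (or a
# vehicle change, omitted here) starts a new sequence, and each rank then receives a
# contiguous run of whole sequences via the prefix-sum loop. Timestamps are hypothetical.
def _demo_sequence_split():
    timestamps = [0.0, 0.5, 1.0, 10.0, 10.5, 20.0]     # seconds
    splits = []
    for i, t in enumerate(timestamps):
        if i == 0 or abs(t - timestamps[i - 1]) > 4:
            splits.append([i])
        else:
            splits[-1].append(i)
    assert splits == [[0, 1, 2], [3, 4], [5]]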
# https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py
import itertools
import copy
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler
# https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
"""Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock.
This method is generally used in `DistributedSampler`,
because the seed should be identical across all processes
in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)
rank, world_size = get_dist_info()
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
class GroupInBatchSampler(Sampler):
"""
Pardon this horrendous name. Basically, we want every sample to be from its own group.
If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on
its own group.
Shuffling is only done for group order, not done within groups.
"""
def __init__(
self,
dataset,
batch_size=1,
world_size=None,
rank=None,
seed=0,
skip_prob=0.5,
sequence_flip_prob=0.1,
):
_rank, _world_size = get_dist_info()
if world_size is None:
world_size = _world_size
if rank is None:
rank = _rank
self.dataset = dataset
self.batch_size = batch_size
self.world_size = world_size
self.rank = rank
self.seed = sync_random_seed(seed)
self.size = len(self.dataset)
assert hasattr(self.dataset, "flag")
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.groups_num = len(self.group_sizes)
self.global_batch_size = batch_size * world_size
assert self.groups_num >= self.global_batch_size
# Now, for efficiency, make a dict group_idx: List[dataset sample_idxs]
self.group_idx_to_sample_idxs = {
group_idx: np.where(self.flag == group_idx)[0].tolist()
for group_idx in range(self.groups_num)
}
# Get a generator per sample idx. Considering samples over all
# GPUs, each sample position has its own generator
self.group_indices_per_global_sample_idx = [
self._group_indices_per_global_sample_idx(
self.rank * self.batch_size + local_sample_idx
)
for local_sample_idx in range(self.batch_size)
]
# Keep track of a buffer of dataset sample idxs for each local sample idx
self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
self.aug_per_local_sample = [None for _ in range(self.batch_size)]
self.skip_prob = skip_prob
self.sequence_flip_prob = sequence_flip_prob
def _infinite_group_indices(self):
g = torch.Generator()
g.manual_seed(self.seed)
while True:
yield from torch.randperm(self.groups_num, generator=g).tolist()
def _group_indices_per_global_sample_idx(self, global_sample_idx):
yield from itertools.islice(
self._infinite_group_indices(),
global_sample_idx,
None,
self.global_batch_size,
)
def __iter__(self):
while True:
curr_batch = []
for local_sample_idx in range(self.batch_size):
skip = (
np.random.uniform() < self.skip_prob
and len(self.buffer_per_local_sample[local_sample_idx]) > 1
)
if len(self.buffer_per_local_sample[local_sample_idx]) == 0:
# Finished current group, refill with next group
# skip = False
new_group_idx = next(
self.group_indices_per_global_sample_idx[
local_sample_idx
]
)
self.buffer_per_local_sample[
local_sample_idx
] = copy.deepcopy(
self.group_idx_to_sample_idxs[new_group_idx]
)
if np.random.uniform() < self.sequence_flip_prob:
self.buffer_per_local_sample[
local_sample_idx
] = self.buffer_per_local_sample[local_sample_idx][
::-1
]
if self.dataset.keep_consistent_seq_aug:
self.aug_per_local_sample[
local_sample_idx
] = self.dataset.get_augmentation()
if not self.dataset.keep_consistent_seq_aug:
self.aug_per_local_sample[
local_sample_idx
] = self.dataset.get_augmentation()
if skip:
self.buffer_per_local_sample[local_sample_idx].pop(0)
curr_batch.append(
dict(
idx=self.buffer_per_local_sample[local_sample_idx].pop(
0
),
aug_config=self.aug_per_local_sample[local_sample_idx],
)
)
yield curr_batch
def __len__(self):
"""Length of base dataset."""
return self.size
def set_epoch(self, epoch):
self.epoch = epoch
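# Minimal sketch of the round-robin scheduling behind _group_indices_per_global_sample_idx:
# each global sample position k reads every global_batch_size-th entry of one shared group
# stream, so no two positions draw the same group at the same step. The stream here is
# unshuffled and truncated to 8 entries to keep the example finite.
def _demo_group_round_robin():
    import itertools
    groups_num, global_batch_size = 8, 4
    def group_stream():
        while True:
            yield from range(groups_num)
    streams = [
        list(itertools.islice(group_stream(), k, 8, global_batch_size))
        for k in range(global_batch_size)
    ]
    assert streams == [[0, 4], [1, 5], [2, 6], [3, 7]]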
# Copyright (c) OpenMMLab. All rights reserved.
import math
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from .sampler import SAMPLER
import random
from IPython import embed
@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
seed (int, optional): random seed used to shuffle the sampler if
``shuffle=True``. This number should be identical across all
processes in the distributed group. Default: 0.
"""
def __init__(
self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
):
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
if rank is None:
rank = _rank
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.seed = seed if seed is not None else 0
assert hasattr(self.dataset, "flag")
self.flag = self.dataset.flag
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, j in enumerate(self.group_sizes):
self.num_samples += (
int(
math.ceil(
self.group_sizes[i]
* 1.0
/ self.samples_per_gpu
/ self.num_replicas
)
)
* self.samples_per_gpu
)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
indices = []
for i, size in enumerate(self.group_sizes):
if size > 0:
indice = np.where(self.flag == i)[0]
assert len(indice) == size
# add .numpy() to avoid bug when selecting indice in parrots.
# TODO: check whether torch.randperm() can be replaced by
# numpy.random.permutation().
indice = indice[
list(torch.randperm(int(size), generator=g).numpy())
].tolist()
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas
)
) * self.samples_per_gpu * self.num_replicas - len(indice)
# pad indice
tmp = indice.copy()
for _ in range(extra // size):
indice.extend(tmp)
indice.extend(tmp[: extra % size])
indices.extend(indice)
assert len(indices) == self.total_size
indices = [
indices[j]
for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g
)
)
for j in range(
i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu
)
]
# subsample
offset = self.num_samples * self.rank
indices = indices[offset : offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
from mmcv.utils.registry import Registry, build_from_cfg
SAMPLER = Registry("sampler")
def build_sampler(cfg, default_args):
return build_from_cfg(cfg, SAMPLER, default_args)
import copy
import cv2
import numpy as np
import torch
from projects.mmdet3d_plugin.core.box3d import *
def box3d_to_corners(box3d):
if isinstance(box3d, torch.Tensor):
box3d = box3d.detach().cpu().numpy()
corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
# use relative origin [0.5, 0.5, 0]
corners_norm = corners_norm - np.array([0.5, 0.5, 0.5])
corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3])
# rotate around z axis
rot_cos = np.cos(box3d[:, YAW])
rot_sin = np.sin(box3d[:, YAW])
rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
rot_mat[:, 0, 0] = rot_cos
rot_mat[:, 0, 1] = -rot_sin
rot_mat[:, 1, 0] = rot_sin
rot_mat[:, 1, 1] = rot_cos
corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1)
corners += box3d[:, None, :3]
return corners
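# Minimal sketch of box3d_to_corners for a single axis-aligned box. The box layout is
# assumed to be [x, y, z, w, l, h, yaw, ...]; the index constants normally come from
# projects.mmdet3d_plugin.core.box3d and are hard-coded here for illustration only.
def _demo_box_corners():
    import numpy as np
    X_, Y_, Z_, W_, L_, H_, YAW_ = 0, 1, 2, 3, 4, 5, 6
    box = np.array([[10.0, 0.0, 1.0, 2.0, 4.0, 2.0, 0.0]])
    corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
    corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] - 0.5
    corners = box[:, None, [W_, L_, H_]] * corners_norm.reshape(1, 8, 3)
    corners += box[:, None, [X_, Y_, Z_]]              # yaw is 0, so no rotation needed
    assert corners.shape == (1, 8, 3)
    assert corners[0, :, 0].min() == 9.0 and corners[0, :, 0].max() == 11.0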
def plot_rect3d_on_img(
img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
):
"""Plot the boundary lines of 3D rectangular on 2D images.
Args:
img (numpy.array): The numpy array of image.
num_rects (int): Number of 3D rectangulars.
rect_corners (numpy.array): Coordinates of the corners of 3D
rectangulars. Should be in the shape of [num_rect, 8, 2].
color (tuple[int], optional): The color to draw bboxes.
Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
line_indices = (
(0, 1),
(0, 3),
(0, 4),
(1, 2),
(1, 5),
(3, 2),
(3, 7),
(4, 5),
(4, 7),
(2, 6),
(5, 6),
(6, 7),
)
h, w = img.shape[:2]
for i in range(num_rects):
corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32)
for start, end in line_indices:
if (
(corners[start, 1] >= h or corners[start, 1] < 0)
or (corners[start, 0] >= w or corners[start, 0] < 0)
) and (
(corners[end, 1] >= h or corners[end, 1] < 0)
or (corners[end, 0] >= w or corners[end, 0] < 0)
):
continue
if isinstance(color[0], int):
cv2.line(
img,
(corners[start, 0], corners[start, 1]),
(corners[end, 0], corners[end, 1]),
color,
thickness,
cv2.LINE_AA,
)
else:
cv2.line(
img,
(corners[start, 0], corners[start, 1]),
(corners[end, 0], corners[end, 1]),
color[i],
thickness,
cv2.LINE_AA,
)
return img.astype(np.uint8)
def draw_lidar_bbox3d_on_img(
bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1
):
"""Project the 3D bbox on 2D plane and draw on input image.
Args:
bboxes3d (:obj:`LiDARInstance3DBoxes`):
3d bbox in lidar coordinate system to visualize.
raw_img (numpy.array): The numpy array of image.
lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix
according to the camera intrinsic parameters.
img_metas (dict): Useless here.
color (tuple[int], optional): The color to draw bboxes.
Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
img = raw_img.copy()
# corners_3d = bboxes3d.corners
corners_3d = box3d_to_corners(bboxes3d)
num_bbox = corners_3d.shape[0]
pts_4d = np.concatenate(
[corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1
)
lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
if isinstance(lidar2img_rt, torch.Tensor):
lidar2img_rt = lidar2img_rt.cpu().numpy()
pts_2d = pts_4d @ lidar2img_rt.T
pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
pts_2d[:, 0] /= pts_2d[:, 2]
pts_2d[:, 1] /= pts_2d[:, 2]
imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2)
return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)
def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4):
img = img.copy()
N = points.shape[0]
points = points.cpu().numpy()
lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
if isinstance(lidar2img_rt, torch.Tensor):
lidar2img_rt = lidar2img_rt.cpu().numpy()
pts_2d = (
np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1)
+ lidar2img_rt[:3, 3]
)
pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5)
pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3]
pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32)
for i in range(N):
for point in pts_2d[i]:
if isinstance(color[0], int):
color_tmp = color
else:
color_tmp = color[i]
cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1)
return img.astype(np.uint8)
def draw_lidar_bbox3d_on_bev(
bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3):
if isinstance(bev_size, (list, tuple)):
bev_h, bev_w = bev_size
else:
bev_h, bev_w = bev_size, bev_size
bev = np.zeros([bev_h, bev_w, 3])
marking_color = (127, 127, 127)
bev_resolution = bev_range / bev_h
for cir in range(int(bev_range / 2 / 10)):
cv2.circle(
bev,
(int(bev_h / 2), int(bev_w / 2)),
int((cir + 1) * 10 / bev_resolution),
marking_color,
thickness=thickness,
)
cv2.line(
bev,
(0, int(bev_h / 2)),
(bev_w, int(bev_h / 2)),
marking_color,
)
cv2.line(
bev,
(int(bev_w / 2), 0),
(int(bev_w / 2), bev_h),
marking_color,
)
if len(bboxes_3d) != 0:
bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][
..., [0, 1]
]
xs = bev_corners[..., 0] / bev_resolution + bev_w / 2
ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2
for obj_idx, (x, y) in enumerate(zip(xs, ys)):
for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)):
if isinstance(color[0], (list, tuple)):
tmp = color[obj_idx]
else:
tmp = color
cv2.line(
bev,
(int(x[p1]), int(y[p1])),
(int(x[p2]), int(y[p2])),
tmp,
thickness=thickness,
)
return bev.astype(np.uint8)
def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)):
vis_imgs = []
for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)):
vis_imgs.append(
draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color)
)
num_imgs = len(vis_imgs)
if num_imgs < 4 or num_imgs % 2 != 0:
vis_imgs = np.concatenate(vis_imgs, axis=1)
else:
vis_imgs = np.concatenate([
np.concatenate(vis_imgs[:num_imgs//2], axis=1),
np.concatenate(vis_imgs[num_imgs//2:], axis=1)
], axis=0)
bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color)
vis_imgs = np.concatenate([bev, vis_imgs], axis=1)
return vis_imgs
from .sparse4d import Sparse4D
from .sparse4d_head import Sparse4DHead
from .blocks import (
DeformableFeatureAggregation,
DenseDepthNet,
AsymmetricFFN,
)
from .instance_bank import InstanceBank
from .detection3d import (
SparseBox3DDecoder,
SparseBox3DTarget,
SparseBox3DRefinementModule,
SparseBox3DKeyPointsGenerator,
SparseBox3DEncoder,
)
__all__ = [
"Sparse4D",
"Sparse4DHead",
"DeformableFeatureAggregation",
"DenseDepthNet",
"AsymmetricFFN",
"InstanceBank",
"SparseBox3DDecoder",
"SparseBox3DTarget",
"SparseBox3DRefinementModule",
"SparseBox3DKeyPointsGenerator",
"SparseBox3DEncoder",
]
from abc import ABC, abstractmethod
__all__ = ["BaseTargetWithDenoising"]
class BaseTargetWithDenoising(ABC):
def __init__(self, num_dn_groups=0, num_temp_dn_groups=0):
super(BaseTargetWithDenoising, self).__init__()
self.num_dn_groups = num_dn_groups
self.num_temp_dn_groups = num_temp_dn_groups
self.dn_metas = None
@abstractmethod
def sample(self, cls_pred, box_pred, cls_target, box_target):
"""
Perform Hungarian matching between predictions and ground truth,
returning the matched ground truth corresponding to the predictions
along with the corresponding regression weights.
"""
def get_dn_anchors(self, cls_target, box_target, *args, **kwargs):
"""
Generate noisy instances for the current frame, with a total of
'self.num_dn_groups' groups.
"""
return None
def update_dn(self, instance_feature, anchor, *args, **kwargs):
"""
Insert the previously saved 'self.dn_metas' into the noisy instances
of the current frame.
"""
def cache_dn(
self,
dn_instance_feature,
dn_anchor,
dn_cls_target,
valid_mask,
dn_id_target,
):
"""
Randomly save information for 'self.num_temp_dn_groups' groups of
temporal noisy instances to 'self.dn_metas'.
"""
if self.num_temp_dn_groups < 0:
return
self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups])
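# Minimal sketch of the Hungarian matching that concrete `sample` implementations (e.g.
# SparseBox3DTarget below) perform between predictions and ground truth, shown on a tiny
# hand-made cost matrix of shape (num_pred, num_gt).
def _demo_hungarian_matching():
    import numpy as np
    from scipy.optimize import linear_sum_assignment
    cost = np.array([[0.1, 0.9],
                     [0.8, 0.2],
                     [0.5, 0.6]])                      # 3 predictions, 2 ground-truth boxes
    pred_idx, gt_idx = linear_sum_assignment(cost)
    assert pred_idx.tolist() == [0, 1] and gt_idx.tolist() == [0, 1]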
# Copyright (c) Horizon Robotics. All rights reserved.
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp.autocast_mode import autocast
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn.bricks.transformer import FFN
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (
ATTENTION,
PLUGIN_LAYERS,
FEEDFORWARD_NETWORK,
)
try:
from ..ops import deformable_aggregation_function as DAF
except:
DAF = None
__all__ = [
"DeformableFeatureAggregation",
"DenseDepthNet",
"AsymmetricFFN",
]
def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
if input_dims is None:
input_dims = embed_dims
layers = []
for _ in range(out_loops):
for _ in range(in_loops):
layers.append(Linear(input_dims, embed_dims))
layers.append(nn.ReLU(inplace=True))
input_dims = embed_dims
layers.append(nn.LayerNorm(embed_dims))
return layers
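# Minimal sketch of what linear_relu_ln(256, in_loops=1, out_loops=2) expands to, written
# with plain torch.nn layers (mmcv's Linear behaves like nn.Linear here).
def _demo_linear_relu_ln():
    import torch.nn as nn
    embed_dims = 256
    block = nn.Sequential(
        nn.Linear(embed_dims, embed_dims), nn.ReLU(inplace=True), nn.LayerNorm(embed_dims),
        nn.Linear(embed_dims, embed_dims), nn.ReLU(inplace=True), nn.LayerNorm(embed_dims),
    )
    assert len(block) == 6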
@ATTENTION.register_module()
class DeformableFeatureAggregation(BaseModule):
def __init__(
self,
embed_dims: int = 256,
num_groups: int = 8,
num_levels: int = 4,
num_cams: int = 6,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
kps_generator: dict = None,
temporal_fusion_module=None,
use_temporal_anchor_embed=True,
use_deformable_func=False,
use_camera_embed=False,
residual_mode="add",
):
super(DeformableFeatureAggregation, self).__init__()
if embed_dims % num_groups != 0:
raise ValueError(
f"embed_dims must be divisible by num_groups, "
f"but got {embed_dims} and {num_groups}"
)
self.group_dims = int(embed_dims / num_groups)
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_groups = num_groups
self.num_cams = num_cams
self.use_temporal_anchor_embed = use_temporal_anchor_embed
if use_deformable_func:
assert DAF is not None, "deformable_aggregation needs to be set up."
self.use_deformable_func = use_deformable_func
self.attn_drop = attn_drop
self.residual_mode = residual_mode
self.proj_drop = nn.Dropout(proj_drop)
kps_generator["embed_dims"] = embed_dims
self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS)
self.num_pts = self.kps_generator.num_pts
if temporal_fusion_module is not None:
if "embed_dims" not in temporal_fusion_module:
temporal_fusion_module["embed_dims"] = embed_dims
self.temp_module = build_from_cfg(
temporal_fusion_module, PLUGIN_LAYERS
)
else:
self.temp_module = None
self.output_proj = Linear(embed_dims, embed_dims)
if use_camera_embed:
self.camera_encoder = Sequential(
*linear_relu_ln(embed_dims, 1, 2, 12)
)
self.weights_fc = Linear(
embed_dims, num_groups * num_levels * self.num_pts
)
else:
self.camera_encoder = None
self.weights_fc = Linear(
embed_dims, num_groups * num_cams * num_levels * self.num_pts
)
def init_weight(self):
constant_init(self.weights_fc, val=0.0, bias=0.0)
xavier_init(self.output_proj, distribution="uniform", bias=0.0)
def forward(
self,
instance_feature: torch.Tensor,
anchor: torch.Tensor,
anchor_embed: torch.Tensor,
feature_maps: List[torch.Tensor],
metas: dict,
**kwargs: dict,
):
bs, num_anchor = instance_feature.shape[:2]
key_points = self.kps_generator(anchor, instance_feature)
weights = self._get_weights(instance_feature, anchor_embed, metas)
if self.use_deformable_func:
points_2d = (
self.project_points(
key_points,
metas["projection_mat"],
metas.get("image_wh"),
)
.permute(0, 2, 3, 1, 4)
.reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
)
weights = (
weights.permute(0, 1, 4, 2, 3, 5)
.contiguous()
.reshape(
bs,
num_anchor,
self.num_pts,
self.num_cams,
self.num_levels,
self.num_groups,
)
)
features = DAF(*feature_maps, points_2d, weights).reshape(
bs, num_anchor, self.embed_dims
)
else:
features = self.feature_sampling(
feature_maps,
key_points,
metas["projection_mat"],
metas.get("image_wh"),
)
features = self.multi_view_level_fusion(features, weights)
features = features.sum(dim=2) # fuse multi-point features
output = self.proj_drop(self.output_proj(features))
if self.residual_mode == "add":
output = output + instance_feature
elif self.residual_mode == "cat":
output = torch.cat([output, instance_feature], dim=-1)
return output
def _get_weights(self, instance_feature, anchor_embed, metas=None):
bs, num_anchor = instance_feature.shape[:2]
feature = instance_feature + anchor_embed
if self.camera_encoder is not None:
camera_embed = self.camera_encoder(
metas["projection_mat"][:, :, :3].reshape(
bs, self.num_cams, -1
)
)
feature = feature[:, :, None] + camera_embed[:, None]
weights = (
self.weights_fc(feature)
.reshape(bs, num_anchor, -1, self.num_groups)
.softmax(dim=-2)
.reshape(
bs,
num_anchor,
self.num_cams,
self.num_levels,
self.num_pts,
self.num_groups,
)
)
if self.training and self.attn_drop > 0:
mask = torch.rand(
bs, num_anchor, self.num_cams, 1, self.num_pts, 1
)
mask = mask.to(device=weights.device, dtype=weights.dtype)
weights = ((mask > self.attn_drop) * weights) / (
1 - self.attn_drop
)
return weights
@staticmethod
def project_points(key_points, projection_mat, image_wh=None):
bs, num_anchor, num_pts = key_points.shape[:3]
pts_extend = torch.cat(
[key_points, torch.ones_like(key_points[..., :1])], dim=-1
)
points_2d = torch.matmul(
projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
).squeeze(-1)
points_2d = points_2d[..., :2] / torch.clamp(
points_2d[..., 2:3], min=1e-5
)
if image_wh is not None:
points_2d = points_2d / image_wh[:, :, None, None]
return points_2d
@staticmethod
def feature_sampling(
feature_maps: List[torch.Tensor],
key_points: torch.Tensor,
projection_mat: torch.Tensor,
image_wh: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_levels = len(feature_maps)
num_cams = feature_maps[0].shape[1]
bs, num_anchor, num_pts = key_points.shape[:3]
points_2d = DeformableFeatureAggregation.project_points(
key_points, projection_mat, image_wh
)
points_2d = points_2d * 2 - 1
points_2d = points_2d.flatten(end_dim=1)
features = []
for fm in feature_maps:
features.append(
torch.nn.functional.grid_sample(
fm.flatten(end_dim=1), points_2d
)
)
features = torch.stack(features, dim=1)
features = features.reshape(
bs, num_cams, num_levels, -1, num_anchor, num_pts
).permute(
0, 4, 1, 2, 5, 3
) # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims
return features
def multi_view_level_fusion(
self,
features: torch.Tensor,
weights: torch.Tensor,
):
bs, num_anchor = weights.shape[:2]
features = weights[..., None] * features.reshape(
features.shape[:-1] + (self.num_groups, self.group_dims)
)
features = features.sum(dim=2).sum(dim=2)
features = features.reshape(
bs, num_anchor, self.num_pts, self.embed_dims
)
return features
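# Minimal sketch of DeformableFeatureAggregation.project_points for one camera and one
# key point, in numpy: homogeneous projection, perspective division with a clamped depth,
# and normalization by image_wh to [0, 1]. The projection matrix and image size are
# hypothetical.
def _demo_project_points():
    import numpy as np
    projection_mat = np.eye(4, dtype=np.float32)       # hypothetical lidar2img
    image_wh = np.array([100.0, 50.0], dtype=np.float32)
    key_point = np.array([20.0, 10.0, 2.0, 1.0], dtype=np.float32)   # homogeneous point
    p = projection_mat @ key_point
    uv = p[:2] / np.clip(p[2:3], 1e-5, None)
    uv_normalized = uv / image_wh
    assert np.allclose(uv_normalized, [0.1, 0.1])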
@PLUGIN_LAYERS.register_module()
class DenseDepthNet(BaseModule):
def __init__(
self,
embed_dims=256,
num_depth_layers=1,
equal_focal=100,
max_depth=60,
loss_weight=1.0,
):
super().__init__()
self.embed_dims = embed_dims
self.equal_focal = equal_focal
self.num_depth_layers = num_depth_layers
self.max_depth = max_depth
self.loss_weight = loss_weight
self.depth_layers = nn.ModuleList()
for i in range(num_depth_layers):
self.depth_layers.append(
nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0)
)
def forward(self, feature_maps, focal=None, gt_depths=None):
if focal is None:
focal = self.equal_focal
else:
focal = focal.reshape(-1)
depths = []
for i, feat in enumerate(feature_maps[: self.num_depth_layers]):
depth = self.depth_layers[i](feat.flatten(end_dim=1).float()).exp()
depth = depth.transpose(0, -1) * focal / self.equal_focal
depth = depth.transpose(0, -1)
depths.append(depth)
if gt_depths is not None and self.training:
loss = self.loss(depths, gt_depths)
return loss
return depths
def loss(self, depth_preds, gt_depths):
loss = 0.0
for pred, gt in zip(depth_preds, gt_depths):
pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1)
gt = gt.reshape(-1)
fg_mask = torch.logical_and(
gt > 0.0, torch.logical_not(torch.isnan(pred))
)
gt = gt[fg_mask]
pred = pred[fg_mask]
pred = torch.clip(pred, 0.0, self.max_depth)
with autocast(enabled=False):
error = torch.abs(pred - gt).sum()
_loss = (
error
/ max(1.0, len(gt) * len(depth_preds))
* self.loss_weight
)
loss = loss + _loss
return loss
@FEEDFORWARD_NETWORK.register_module()
class AsymmetricFFN(BaseModule):
def __init__(
self,
in_channels=None,
pre_norm=None,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type="ReLU", inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True,
init_cfg=None,
**kwargs,
):
super(AsymmetricFFN, self).__init__(init_cfg)
assert num_fcs >= 2, (
"num_fcs should be no less " f"than 2. got {num_fcs}."
)
self.in_channels = in_channels
self.pre_norm = pre_norm
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = []
if in_channels is None:
in_channels = embed_dims
if pre_norm is not None:
self.pre_norm = build_norm_layer(pre_norm, in_channels)[1]
for _ in range(num_fcs - 1):
layers.append(
Sequential(
Linear(in_channels, feedforward_channels),
self.activate,
nn.Dropout(ffn_drop),
)
)
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = (
build_dropout(dropout_layer)
if dropout_layer
else torch.nn.Identity()
)
self.add_identity = add_identity
if self.add_identity:
self.identity_fc = (
torch.nn.Identity()
if in_channels == embed_dims
else Linear(self.in_channels, embed_dims)
)
def forward(self, x, identity=None):
if self.pre_norm is not None:
x = self.pre_norm(x)
out = self.layers(x)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = x
identity = self.identity_fc(identity)
return identity + self.dropout_layer(out)
from .decoder import SparseBox3DDecoder
from .target import SparseBox3DTarget
from .detection3d_blocks import (
SparseBox3DRefinementModule,
SparseBox3DKeyPointsGenerator,
SparseBox3DEncoder,
)
from .losses import SparseBox3DLoss
# Copyright (c) Horizon Robotics. All rights reserved.
from typing import Optional
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
from projects.mmdet3d_plugin.core.box3d import *
@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
def __init__(
self,
num_output: int = 300,
score_threshold: Optional[float] = None,
sorted: bool = True,
):
super(SparseBox3DDecoder, self).__init__()
self.num_output = num_output
self.score_threshold = score_threshold
self.sorted = sorted
def decode_box(self, box):
yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW])
box = torch.cat(
[
box[:, [X, Y, Z]],
box[:, [W, L, H]].exp(),
yaw[:, None],
box[:, VX:],
],
dim=-1,
)
return box
def decode(
self,
cls_scores,
box_preds,
instance_id=None,
qulity=None,
output_idx=-1,
):
squeeze_cls = instance_id is not None
cls_scores = cls_scores[output_idx].sigmoid()
if squeeze_cls:
cls_scores, cls_ids = cls_scores.max(dim=-1)
cls_scores = cls_scores.unsqueeze(dim=-1)
box_preds = box_preds[output_idx]
bs, num_pred, num_cls = cls_scores.shape
cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
self.num_output, dim=1, sorted=self.sorted
)
if not squeeze_cls:
cls_ids = indices % num_cls
if self.score_threshold is not None:
mask = cls_scores >= self.score_threshold
if qulity is not None:
centerness = qulity[output_idx][..., CNS]
centerness = torch.gather(centerness, 1, indices // num_cls)
cls_scores_origin = cls_scores.clone()
cls_scores *= centerness.sigmoid()
cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
if not squeeze_cls:
cls_ids = torch.gather(cls_ids, 1, idx)
if self.score_threshold is not None:
mask = torch.gather(mask, 1, idx)
indices = torch.gather(indices, 1, idx)
output = []
for i in range(bs):
category_ids = cls_ids[i]
if squeeze_cls:
category_ids = category_ids[indices[i]]
scores = cls_scores[i]
box = box_preds[i, indices[i] // num_cls]
if self.score_threshold is not None:
category_ids = category_ids[mask[i]]
scores = scores[mask[i]]
box = box[mask[i]]
if qulity is not None:
scores_origin = cls_scores_origin[i]
if self.score_threshold is not None:
scores_origin = scores_origin[mask[i]]
box = self.decode_box(box)
output.append(
{
"boxes_3d": box.cpu(),
"scores_3d": scores.cpu(),
"labels_3d": category_ids.cpu(),
}
)
if qulity is not None:
output[-1]["cls_scores"] = scores_origin.cpu()
if instance_id is not None:
ids = instance_id[i, indices[i]]
if self.score_threshold is not None:
ids = ids[mask[i]]
output[-1]["instance_ids"] = ids
return output
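# Minimal sketch of decode_box: yaw is recovered from its (sin, cos) encoding and the
# log-encoded sizes are exponentiated. The state layout [x, y, z, log w, log l, log h,
# sin yaw, cos yaw, vx, vy, vz] follows projects.mmdet3d_plugin.core.box3d; its index
# constants are hard-coded here for illustration only.
def _demo_decode_box():
    import math
    import torch
    box = torch.tensor([[1.0, 2.0, 0.0,
                         math.log(2.0), math.log(4.0), math.log(1.5),
                         1.0, 0.0,                     # sin(yaw)=1, cos(yaw)=0 -> pi/2
                         0.0, 0.0, 0.0]])
    yaw = torch.atan2(box[:, 6], box[:, 7])
    decoded = torch.cat([box[:, :3], box[:, 3:6].exp(), yaw[:, None], box[:, 8:]], dim=-1)
    assert torch.allclose(decoded[0, 3:7], torch.tensor([2.0, 4.0, 1.5, math.pi / 2]))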
import torch
import torch.nn as nn
import numpy as np
from mmcv.cnn import Linear, Scale, bias_init_with_prob
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.registry import (
PLUGIN_LAYERS,
POSITIONAL_ENCODING,
)
from projects.mmdet3d_plugin.core.box3d import *
from ..blocks import linear_relu_ln
__all__ = [
"SparseBox3DRefinementModule",
"SparseBox3DKeyPointsGenerator",
"SparseBox3DEncoder",
]
@POSITIONAL_ENCODING.register_module()
class SparseBox3DEncoder(BaseModule):
def __init__(
self,
embed_dims,
vel_dims=3,
mode="add",
output_fc=True,
in_loops=1,
out_loops=2,
):
super().__init__()
assert mode in ["add", "cat"]
self.embed_dims = embed_dims
self.vel_dims = vel_dims
self.mode = mode
def embedding_layer(input_dims, output_dims):
return nn.Sequential(
*linear_relu_ln(output_dims, in_loops, out_loops, input_dims)
)
if not isinstance(embed_dims, (list, tuple)):
embed_dims = [embed_dims] * 5
self.pos_fc = embedding_layer(3, embed_dims[0])
self.size_fc = embedding_layer(3, embed_dims[1])
self.yaw_fc = embedding_layer(2, embed_dims[2])
if vel_dims > 0:
self.vel_fc = embedding_layer(self.vel_dims, embed_dims[3])
if output_fc:
self.output_fc = embedding_layer(embed_dims[-1], embed_dims[-1])
else:
self.output_fc = None
def forward(self, box_3d: torch.Tensor):
pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]])
size_feat = self.size_fc(box_3d[..., [W, L, H]])
yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]])
if self.mode == "add":
output = pos_feat + size_feat + yaw_feat
elif self.mode == "cat":
output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1)
if self.vel_dims > 0:
vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims])
if self.mode == "add":
output = output + vel_feat
elif self.mode == "cat":
output = torch.cat([output, vel_feat], dim=-1)
if self.output_fc is not None:
output = self.output_fc(output)
return output
@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
def __init__(
self,
embed_dims=256,
output_dim=11,
num_cls=10,
normalize_yaw=False,
refine_yaw=False,
with_cls_branch=True,
with_quality_estimation=False,
):
super(SparseBox3DRefinementModule, self).__init__()
self.embed_dims = embed_dims
self.output_dim = output_dim
self.num_cls = num_cls
self.normalize_yaw = normalize_yaw
self.refine_yaw = refine_yaw
self.refine_state = [X, Y, Z, W, L, H]
if self.refine_yaw:
self.refine_state += [SIN_YAW, COS_YAW]
self.layers = nn.Sequential(
*linear_relu_ln(embed_dims, 2, 2),
Linear(self.embed_dims, self.output_dim),
Scale([1.0] * self.output_dim),
)
self.with_cls_branch = with_cls_branch
if with_cls_branch:
self.cls_layers = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(self.embed_dims, self.num_cls),
)
self.with_quality_estimation = with_quality_estimation
if with_quality_estimation:
self.quality_layers = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(self.embed_dims, 2),
)
def init_weight(self):
if self.with_cls_branch:
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.cls_layers[-1].bias, bias_init)
def forward(
self,
instance_feature: torch.Tensor,
anchor: torch.Tensor,
anchor_embed: torch.Tensor,
time_interval: torch.Tensor = 1.0,
return_cls=True,
):
feature = instance_feature + anchor_embed
output = self.layers(feature)
output[..., self.refine_state] = (
output[..., self.refine_state] + anchor[..., self.refine_state]
)
if self.normalize_yaw:
output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(
output[..., [SIN_YAW, COS_YAW]], dim=-1
)
if self.output_dim > 8:
if not isinstance(time_interval, torch.Tensor):
time_interval = instance_feature.new_tensor(time_interval)
translation = torch.transpose(output[..., VX:], 0, -1)
velocity = torch.transpose(translation / time_interval, 0, -1)
output[..., VX:] = velocity + anchor[..., VX:]
if return_cls:
assert self.with_cls_branch, "Without classification layers !!!"
cls = self.cls_layers(instance_feature)
else:
cls = None
if return_cls and self.with_quality_estimation:
quality = self.quality_layers(feature)
else:
quality = None
return output, cls, quality
@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
def __init__(
self,
embed_dims=256,
num_learnable_pts=0,
fix_scale=None,
):
super(SparseBox3DKeyPointsGenerator, self).__init__()
self.embed_dims = embed_dims
self.num_learnable_pts = num_learnable_pts
if fix_scale is None:
fix_scale = ((0.0, 0.0, 0.0),)
self.fix_scale = nn.Parameter(
torch.tensor(fix_scale), requires_grad=False
)
self.num_pts = len(self.fix_scale) + num_learnable_pts
if num_learnable_pts > 0:
self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)
def init_weight(self):
if self.num_learnable_pts > 0:
xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)
def forward(
self,
anchor,
instance_feature=None,
T_cur2temp_list=None,
cur_timestamp=None,
temp_timestamps=None,
):
bs, num_anchor = anchor.shape[:2]
size = anchor[..., None, [W, L, H]].exp()
key_points = self.fix_scale * size
if self.num_learnable_pts > 0 and instance_feature is not None:
learnable_scale = (
self.learnable_fc(instance_feature)
.reshape(bs, num_anchor, self.num_learnable_pts, 3)
.sigmoid()
- 0.5
)
key_points = torch.cat(
[key_points, learnable_scale * size], dim=-2
)
rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])
rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
rotation_mat[:, :, 2, 2] = 1
key_points = torch.matmul(
rotation_mat[:, :, None], key_points[..., None]
).squeeze(-1)
key_points = key_points + anchor[..., None, [X, Y, Z]]
if (
cur_timestamp is None
or temp_timestamps is None
or T_cur2temp_list is None
or len(temp_timestamps) == 0
):
return key_points
temp_key_points_list = []
velocity = anchor[..., VX:]
for i, t_time in enumerate(temp_timestamps):
time_interval = cur_timestamp - t_time
translation = (
velocity
* time_interval.to(dtype=velocity.dtype)[:, None, None]
)
temp_key_points = key_points - translation[:, :, None]
T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
temp_key_points = (
T_cur2temp[:, None, None, :3]
@ torch.cat(
[
temp_key_points,
torch.ones_like(temp_key_points[..., :1]),
],
dim=-1,
).unsqueeze(-1)
)
temp_key_points = temp_key_points.squeeze(-1)
temp_key_points_list.append(temp_key_points)
return key_points, temp_key_points_list
@staticmethod
def anchor_projection(
anchor,
T_src2dst_list,
src_timestamp=None,
dst_timestamps=None,
time_intervals=None,
):
dst_anchors = []
for i in range(len(T_src2dst_list)):
vel = anchor[..., VX:]
vel_dim = vel.shape[-1]
T_src2dst = torch.unsqueeze(
T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
)
center = anchor[..., [X, Y, Z]]
if time_intervals is not None:
time_interval = time_intervals[i]
elif src_timestamp is not None and dst_timestamps is not None:
time_interval = (src_timestamp - dst_timestamps[i]).to(
dtype=vel.dtype
)
else:
time_interval = None
if time_interval is not None:
translation = vel.transpose(0, -1) * time_interval
translation = translation.transpose(0, -1)
center = center - translation
center = (
torch.matmul(
T_src2dst[..., :3, :3], center[..., None]
).squeeze(dim=-1)
+ T_src2dst[..., :3, 3]
)
size = anchor[..., [W, L, H]]
yaw = torch.matmul(
T_src2dst[..., :2, :2],
anchor[..., [COS_YAW, SIN_YAW], None],
).squeeze(-1)
vel = torch.matmul(
T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
).squeeze(-1)
dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
# TODO: Fix bug
# index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
# index = torch.tensor(index, device=dst_anchor.device)
# index = torch.argsort(index)
# dst_anchor = dst_anchor.index_select(dim=-1, index=index)
dst_anchors.append(dst_anchor)
return dst_anchors
@staticmethod
def distance(anchor):
return torch.norm(anchor[..., :2], p=2, dim=-1)
import torch
import torch.nn as nn
from mmcv.utils import build_from_cfg
from mmdet.models.builder import LOSSES
from projects.mmdet3d_plugin.core.box3d import *
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
def __init__(
self,
loss_box,
loss_centerness=None,
loss_yawness=None,
cls_allow_reverse=None,
):
super().__init__()
def build(cfg, registry):
if cfg is None:
return None
return build_from_cfg(cfg, registry)
self.loss_box = build(loss_box, LOSSES)
self.loss_cns = build(loss_centerness, LOSSES)
self.loss_yns = build(loss_yawness, LOSSES)
self.cls_allow_reverse = cls_allow_reverse
def forward(
self,
box,
box_target,
weight=None,
avg_factor=None,
suffix="",
quality=None,
cls_target=None,
**kwargs,
):
# Some categories do not distinguish between positive and negative
# directions. For example, barrier in nuScenes dataset.
if self.cls_allow_reverse is not None and cls_target is not None:
if_reverse = (
torch.nn.functional.cosine_similarity(
box_target[..., [SIN_YAW, COS_YAW]],
box[..., [SIN_YAW, COS_YAW]],
dim=-1,
)
< 0
)
if_reverse = (
torch.isin(
cls_target, cls_target.new_tensor(self.cls_allow_reverse)
)
& if_reverse
)
box_target[..., [SIN_YAW, COS_YAW]] = torch.where(
if_reverse[..., None],
-box_target[..., [SIN_YAW, COS_YAW]],
box_target[..., [SIN_YAW, COS_YAW]],
)
output = {}
box_loss = self.loss_box(
box, box_target, weight=weight, avg_factor=avg_factor
)
output[f"loss_box{suffix}"] = box_loss
if quality is not None:
cns = quality[..., CNS]
yns = quality[..., YNS].sigmoid()
cns_target = torch.norm(
box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1
)
cns_target = torch.exp(-cns_target)
cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor)
output[f"loss_cns{suffix}"] = cns_loss
yns_target = (
torch.nn.functional.cosine_similarity(
box_target[..., [SIN_YAW, COS_YAW]],
box[..., [SIN_YAW, COS_YAW]],
dim=-1,
)
> 0
)
yns_target = yns_target.float()
yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor)
output[f"loss_yns{suffix}"] = yns_loss
return output
import torch
import numpy as np
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from mmdet.core.bbox.builder import BBOX_SAMPLERS
from projects.mmdet3d_plugin.core.box3d import *
from ..base_target import BaseTargetWithDenoising
__all__ = ["SparseBox3DTarget"]
@BBOX_SAMPLERS.register_module()
class SparseBox3DTarget(BaseTargetWithDenoising):
def __init__(
self,
cls_weight=2.0,
alpha=0.25,
gamma=2,
eps=1e-12,
box_weight=0.25,
reg_weights=None,
cls_wise_reg_weights=None,
num_dn_groups=0,
dn_noise_scale=0.5,
max_dn_gt=32,
add_neg_dn=True,
num_temp_dn_groups=0,
):
super(SparseBox3DTarget, self).__init__(
num_dn_groups, num_temp_dn_groups
)
self.cls_weight = cls_weight
self.box_weight = box_weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
self.reg_weights = reg_weights
if self.reg_weights is None:
self.reg_weights = [1.0] * 8 + [0.0] * 2
self.cls_wise_reg_weights = cls_wise_reg_weights
self.dn_noise_scale = dn_noise_scale
self.max_dn_gt = max_dn_gt
self.add_neg_dn = add_neg_dn
def encode_reg_target(self, box_target, device=None):
outputs = []
for box in box_target:
output = torch.cat(
[
box[..., [X, Y, Z]],
box[..., [W, L, H]].log(),
torch.sin(box[..., YAW]).unsqueeze(-1),
torch.cos(box[..., YAW]).unsqueeze(-1),
box[..., YAW + 1 :],
],
dim=-1,
)
if device is not None:
output = output.to(device=device)
outputs.append(output)
return outputs
def sample(
self,
cls_pred,
box_pred,
cls_target,
box_target,
):
bs, num_pred, num_cls = cls_pred.shape
cls_cost = self._cls_cost(cls_pred, cls_target)
box_target = self.encode_reg_target(box_target, box_pred.device)
instance_reg_weights = []
for i in range(len(box_target)):
weights = torch.logical_not(box_target[i].isnan()).to(
dtype=box_target[i].dtype
)
if self.cls_wise_reg_weights is not None:
for cls, weight in self.cls_wise_reg_weights.items():
weights = torch.where(
(cls_target[i] == cls)[:, None],
weights.new_tensor(weight),
weights,
)
instance_reg_weights.append(weights)
box_cost = self._box_cost(box_pred, box_target, instance_reg_weights)
indices = []
for i in range(bs):
if cls_cost[i] is not None and box_cost[i] is not None:
cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy()
cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
assign = linear_sum_assignment(cost)
indices.append(
[cls_pred.new_tensor(x, dtype=torch.int64) for x in assign]
)
else:
indices.append([None, None])
output_cls_target = (
cls_target[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls
)
output_box_target = box_pred.new_zeros(box_pred.shape)
output_reg_weights = box_pred.new_zeros(box_pred.shape)
for i, (pred_idx, target_idx) in enumerate(indices):
if len(cls_target[i]) == 0:
continue
output_cls_target[i, pred_idx] = cls_target[i][target_idx]
output_box_target[i, pred_idx] = box_target[i][target_idx]
output_reg_weights[i, pred_idx] = instance_reg_weights[i][
target_idx
]
return output_cls_target, output_box_target, output_reg_weights
def _cls_cost(self, cls_pred, cls_target):
bs = cls_pred.shape[0]
cls_pred = cls_pred.sigmoid()
cost = []
for i in range(bs):
if len(cls_target[i]) > 0:
neg_cost = (
-(1 - cls_pred[i] + self.eps).log()
* (1 - self.alpha)
* cls_pred[i].pow(self.gamma)
)
pos_cost = (
-(cls_pred[i] + self.eps).log()
* self.alpha
* (1 - cls_pred[i]).pow(self.gamma)
)
cost.append(
(pos_cost[:, cls_target[i]] - neg_cost[:, cls_target[i]])
* self.cls_weight
)
else:
cost.append(None)
return cost
def _box_cost(self, box_pred, box_target, instance_reg_weights):
bs = box_pred.shape[0]
cost = []
for i in range(bs):
if len(box_target[i]) > 0:
cost.append(
torch.sum(
torch.abs(box_pred[i, :, None] - box_target[i][None])
* instance_reg_weights[i][None]
* box_pred.new_tensor(self.reg_weights),
dim=-1,
)
* self.box_weight
)
else:
cost.append(None)
return cost
def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
if self.num_dn_groups <= 0:
return None
if self.num_temp_dn_groups <= 0:
gt_instance_id = None
if self.max_dn_gt > 0:
cls_target = [x[: self.max_dn_gt] for x in cls_target]
box_target = [x[: self.max_dn_gt] for x in box_target]
if gt_instance_id is not None:
gt_instance_id = [x[: self.max_dn_gt] for x in gt_instance_id]
max_dn_gt = max([len(x) for x in cls_target])
if max_dn_gt == 0:
return None
cls_target = torch.stack(
[
F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
for x in cls_target
]
)
box_target = self.encode_reg_target(box_target, cls_target.device)
box_target = torch.stack(
[F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0])) for x in box_target]
)
box_target = torch.where(
cls_target[..., None] == -1, box_target.new_tensor(0), box_target
)
if gt_instance_id is not None:
gt_instance_id = torch.stack(
[
F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
for x in gt_instance_id
]
)
bs, num_gt, state_dims = box_target.shape
if self.num_dn_groups > 1:
cls_target = cls_target.tile(self.num_dn_groups, 1)
box_target = box_target.tile(self.num_dn_groups, 1, 1)
if gt_instance_id is not None:
gt_instance_id = gt_instance_id.tile(self.num_dn_groups, 1)
noise = torch.rand_like(box_target) * 2 - 1
noise *= box_target.new_tensor(self.dn_noise_scale)
dn_anchor = box_target + noise
if self.add_neg_dn:
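            # Negative DN samples: noise magnitude in [1, 2) before scaling,
            # i.e. pushed farther from the GT than the positive samples.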
noise_neg = torch.rand_like(box_target) + 1
flag = torch.where(
torch.rand_like(box_target) > 0.5,
noise_neg.new_tensor(1),
noise_neg.new_tensor(-1),
)
noise_neg *= flag
noise_neg *= box_target.new_tensor(self.dn_noise_scale)
dn_anchor = torch.cat([dn_anchor, box_target + noise_neg], dim=1)
num_gt *= 2
box_cost = self._box_cost(
dn_anchor, box_target, torch.ones_like(box_target)
)
dn_box_target = torch.zeros_like(dn_anchor)
dn_cls_target = -torch.ones_like(cls_target) * 3
if gt_instance_id is not None:
dn_id_target = -torch.ones_like(gt_instance_id)
if self.add_neg_dn:
dn_cls_target = torch.cat([dn_cls_target, dn_cls_target], dim=1)
if gt_instance_id is not None:
dn_id_target = torch.cat([dn_id_target, dn_id_target], dim=1)
for i in range(dn_anchor.shape[0]):
cost = box_cost[i].cpu().numpy()
anchor_idx, gt_idx = linear_sum_assignment(cost)
anchor_idx = dn_anchor.new_tensor(anchor_idx, dtype=torch.int64)
gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
if gt_instance_id is not None:
dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]
dn_anchor = (
dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
.permute(1, 0, 2, 3)
.flatten(1, 2)
)
dn_box_target = (
dn_box_target.reshape(self.num_dn_groups, bs, num_gt, state_dims)
.permute(1, 0, 2, 3)
.flatten(1, 2)
)
dn_cls_target = (
dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
if gt_instance_id is not None:
dn_id_target = (
dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
else:
dn_id_target = None
valid_mask = dn_cls_target >= 0
if self.add_neg_dn:
cls_target = (
torch.cat([cls_target, cls_target], dim=1)
.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
valid_mask = torch.logical_or(
valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )  # valid marks items that do not come from padding.
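        # Group-wise attention mask: True entries are masked out, so DN
        # queries from different groups cannot attend to each other.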
attn_mask = dn_box_target.new_ones(
num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
)
for i in range(self.num_dn_groups):
start = num_gt * i
end = start + num_gt
attn_mask[start:end, start:end] = 0
attn_mask = attn_mask == 1
dn_cls_target = dn_cls_target.long()
return (
dn_anchor,
dn_box_target,
dn_cls_target,
attn_mask,
valid_mask,
dn_id_target,
)
def update_dn(
self,
instance_feature,
anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id_target,
num_noraml_anchor,
temporal_valid_mask,
):
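        # Merge the DN queries cached from the previous frame (temporal DN)
        # with the current ones: targets of the cached queries are re-derived
        # by matching instance ids, lengths are padded/truncated to agree, and
        # invalid frames fall back to the current-frame metas.
        # `num_noraml_anchor` is the number of normal (non-DN) anchors.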
bs, num_anchor = instance_feature.shape[:2]
if temporal_valid_mask is None:
self.dn_metas = None
if self.dn_metas is None or num_noraml_anchor >= num_anchor:
return (
instance_feature,
anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id_target,
)
# split instance_feature and anchor into non-dn and dn
num_dn = num_anchor - num_noraml_anchor
dn_instance_feature = instance_feature[:, -num_dn:]
dn_anchor = anchor[:, -num_dn:]
instance_feature = instance_feature[:, :num_noraml_anchor]
anchor = anchor[:, :num_noraml_anchor]
# reshape all dn metas from (bs,num_all_dn,xxx)
# to (bs, dn_group, num_dn_per_group, xxx)
num_dn_groups = self.num_dn_groups
num_dn = num_dn // num_dn_groups
dn_feat = dn_instance_feature.reshape(bs, num_dn_groups, num_dn, -1)
dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
dn_reg_target = dn_reg_target.reshape(bs, num_dn_groups, num_dn, -1)
dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
if dn_id_target is not None:
dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)
# update temp_dn_metas by instance_id
temp_dn_feat = self.dn_metas["dn_instance_feature"]
_, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
temp_dn_id = self.dn_metas["dn_id_target"]
# bs, num_temp_dn_groups, num_temp_dn, num_dn
match = temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
temp_reg_target = (
match[..., None] * dn_reg_target[:, :num_temp_dn_groups, None]
).sum(dim=3)
temp_cls_target = torch.where(
torch.all(torch.logical_not(match), dim=-1),
self.dn_metas["dn_cls_target"].new_tensor(-1),
self.dn_metas["dn_cls_target"],
)
temp_valid_mask = self.dn_metas["valid_mask"]
temp_dn_anchor = self.dn_metas["dn_anchor"]
        # handle the length misalignment between temp_dn and dn caused by a
        # change of num_gt, then concatenate temp_dn and dn
temp_dn_metas = [
temp_dn_feat,
temp_dn_anchor,
temp_reg_target,
temp_cls_target,
temp_valid_mask,
temp_dn_id,
]
dn_metas = [
dn_feat,
dn_anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id,
]
output = []
for i, (temp_meta, meta) in enumerate(zip(temp_dn_metas, dn_metas)):
if num_temp_dn < num_dn:
pad = (0, num_dn - num_temp_dn)
if temp_meta.dim() == 4:
pad = (0, 0) + pad
else:
assert temp_meta.dim() == 3
temp_meta = F.pad(temp_meta, pad, value=0)
else:
temp_meta = temp_meta[:, :, :num_dn]
mask = temporal_valid_mask[:, None, None]
if meta.dim() == 4:
mask = mask.unsqueeze(dim=-1)
temp_meta = torch.where(
mask, temp_meta, meta[:, :num_temp_dn_groups]
)
meta = torch.cat([temp_meta, meta[:, num_temp_dn_groups:]], dim=1)
meta = meta.flatten(1, 2)
output.append(meta)
output[0] = torch.cat([instance_feature, output[0]], dim=1)
output[1] = torch.cat([anchor, output[1]], dim=1)
return output
def cache_dn(
self,
dn_instance_feature,
dn_anchor,
dn_cls_target,
valid_mask,
dn_id_target,
):
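        # Cache a random subset of num_temp_dn_groups DN groups (detached) so
        # they can be reused as temporal DN queries in the next frame.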
if self.num_temp_dn_groups < 0:
return
num_dn_groups = self.num_dn_groups
bs, num_dn = dn_instance_feature.shape[:2]
num_temp_dn = num_dn // num_dn_groups
temp_group_mask = (
torch.randperm(num_dn_groups) < self.num_temp_dn_groups
)
temp_group_mask = temp_group_mask.to(device=dn_anchor.device)
dn_instance_feature = dn_instance_feature.detach().reshape(
bs, num_dn_groups, num_temp_dn, -1
)[:, temp_group_mask]
dn_anchor = dn_anchor.detach().reshape(
bs, num_dn_groups, num_temp_dn, -1
)[:, temp_group_mask]
dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_temp_dn)[
:, temp_group_mask
]
valid_mask = valid_mask.reshape(bs, num_dn_groups, num_temp_dn)[
:, temp_group_mask
]
if dn_id_target is not None:
dn_id_target = dn_id_target.reshape(
bs, num_dn_groups, num_temp_dn
)[:, temp_group_mask]
self.dn_metas = dict(
dn_instance_feature=dn_instance_feature,
dn_anchor=dn_anchor,
dn_cls_target=dn_cls_target,
valid_mask=valid_mask,
dn_id_target=dn_id_target,
)
# Copyright (c) Horizon Robotics. All rights reserved.
from inspect import signature
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
DETECTORS,
BaseDetector,
build_backbone,
build_head,
build_neck,
)
from .grid_mask import GridMask
try:
from ..ops import feature_maps_format
DAF_VALID = True
except ImportError:
DAF_VALID = False
__all__ = ["Sparse4D"]
@DETECTORS.register_module()
class Sparse4D(BaseDetector):
def __init__(
self,
img_backbone,
head,
img_neck=None,
init_cfg=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
use_grid_mask=True,
use_deformable_func=False,
depth_branch=None,
):
super(Sparse4D, self).__init__(init_cfg=init_cfg)
if pretrained is not None:
            img_backbone.pretrained = pretrained
self.img_backbone = build_backbone(img_backbone)
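        # NOTE: the backbone is moved to CUDA in channels_last memory format;
        # this assumes a CUDA device is available at construction time.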
self.img_backbone = self.img_backbone.to(device='cuda', memory_format=torch.channels_last)
        if img_neck is not None:
            self.img_neck = build_neck(img_neck)
        else:
            self.img_neck = None
self.head = build_head(head)
self.use_grid_mask = use_grid_mask
if use_deformable_func:
assert DAF_VALID, "deformable_aggregation needs to be set up."
self.use_deformable_func = use_deformable_func
if depth_branch is not None:
self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
else:
self.depth_branch = None
if use_grid_mask:
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
)
@auto_fp16(apply_to=("img",), out_fp32=True)
def extract_feat(self, img, return_depth=False, metas=None):
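        # Flatten the camera dimension into the batch, run backbone + neck,
        # then reshape the feature maps back to (bs, num_cams, ...). Dense
        # depth is predicted only when a depth branch is configured.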
bs = img.shape[0]
if img.dim() == 5: # multi-view
num_cams = img.shape[1]
img = img.flatten(end_dim=1)
else:
num_cams = 1
img = img.to(memory_format=torch.channels_last)
if self.use_grid_mask:
img = self.grid_mask(img)
if "metas" in signature(self.img_backbone.forward).parameters:
feature_maps = self.img_backbone(img, num_cams, metas=metas)
else:
feature_maps = self.img_backbone(img)
if self.img_neck is not None:
feature_maps = list(self.img_neck(feature_maps))
for i, feat in enumerate(feature_maps):
feature_maps[i] = torch.reshape(
feat, (bs, num_cams) + feat.shape[1:]
)
if return_depth and self.depth_branch is not None:
depths = self.depth_branch(feature_maps, metas.get("focal"))
else:
depths = None
if self.use_deformable_func:
feature_maps = feature_maps_format(feature_maps)
if return_depth:
return feature_maps, depths
return feature_maps
@force_fp32(apply_to=("img",))
def forward(self, img, **data):
if self.training:
return self.forward_train(img, **data)
else:
return self.forward_test(img, **data)
def forward_train(self, img, **data):
feature_maps, depths = self.extract_feat(img, True, data)
model_outs = self.head(feature_maps, data)
output = self.head.loss(model_outs, data)
if depths is not None and "gt_depth" in data:
output["loss_dense_depth"] = self.depth_branch.loss(
depths, data["gt_depth"]
)
return output
def forward_test(self, img, **data):
if isinstance(img, list):
return self.aug_test(img, **data)
else:
return self.simple_test(img, **data)
def simple_test(self, img, **data):
feature_maps = self.extract_feat(img)
model_outs = self.head(feature_maps, data)
results = self.head.post_process(model_outs)
output = [{"img_bbox": result} for result in results]
return output
def aug_test(self, img, **data):
        # fake test-time augmentation: only the first augmented sample is used
for key in data.keys():
if isinstance(data[key], list):
data[key] = data[key][0]
return self.simple_test(img[0], **data)