Commit afe88104 (parent a48c4071), authored by lishj6 — "init0905"
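# ---------------------------------------------------------------------------
# First file in this dump: likely base_target.py, judging from the later
# import `from ..base_target import BaseTargetWithDenoising`.
# ---------------------------------------------------------------------------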
from abc import ABC, abstractmethod
__all__ = ["BaseTargetWithDenoising"]
class BaseTargetWithDenoising(ABC):
def __init__(self, num_dn_groups=0, num_temp_dn_groups=0):
super(BaseTargetWithDenoising, self).__init__()
self.num_dn_groups = num_dn_groups
self.num_temp_dn_groups = num_temp_dn_groups
self.dn_metas = None
@abstractmethod
def sample(self, cls_pred, box_pred, cls_target, box_target):
"""
Perform Hungarian matching between predictions and ground truth,
returning the matched ground truth corresponding to the predictions
along with the corresponding regression weights.
"""
def get_dn_anchors(self, cls_target, box_target, *args, **kwargs):
"""
Generate noisy instances for the current frame, with a total of
'self.num_dn_groups' groups.
"""
return None
def update_dn(self, instance_feature, anchor, *args, **kwargs):
"""
Insert the previously saved 'self.dn_metas' into the noisy instances
of the current frame.
"""
def cache_dn(
self,
dn_instance_feature,
dn_anchor,
dn_cls_target,
valid_mask,
dn_id_target,
):
"""
Randomly save information for 'self.num_temp_dn_groups' groups of
temporal noisy instances to 'self.dn_metas'.
"""
if self.num_temp_dn_groups < 0:
return
self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups])
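# ---------------------------------------------------------------------------
# New file: likely blocks.py, judging from the later imports
# `from ..blocks import linear_relu_ln` and
# `from ..blocks import DeformableFeatureAggregation as DFG`.
# ---------------------------------------------------------------------------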
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp.autocast_mode import autocast
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn.bricks.transformer import FFN
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (
ATTENTION,
PLUGIN_LAYERS,
FEEDFORWARD_NETWORK,
)
try:
    from ..ops import deformable_aggregation_function as DAF
except ImportError:
    DAF = None
__all__ = [
"DeformableFeatureAggregation",
"DenseDepthNet",
"AsymmetricFFN",
]
def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
if input_dims is None:
input_dims = embed_dims
layers = []
for _ in range(out_loops):
for _ in range(in_loops):
layers.append(Linear(input_dims, embed_dims))
layers.append(nn.ReLU(inplace=True))
input_dims = embed_dims
layers.append(nn.LayerNorm(embed_dims))
return layers
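# A quick sketch of what the builder above returns (illustrative call, not
# from this commit):
#   linear_relu_ln(256, in_loops=1, out_loops=2)
#   -> [Linear(256, 256), ReLU, LayerNorm(256),
#       Linear(256, 256), ReLU, LayerNorm(256)]
# i.e. `out_loops` repetitions of (`in_loops` x Linear+ReLU) + LayerNorm.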
@ATTENTION.register_module()
class DeformableFeatureAggregation(BaseModule):
def __init__(
self,
embed_dims: int = 256,
num_groups: int = 8,
num_levels: int = 4,
num_cams: int = 6,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
        kps_generator: Optional[dict] = None,
temporal_fusion_module=None,
use_temporal_anchor_embed=True,
use_deformable_func=False,
use_camera_embed=False,
residual_mode="add",
):
super(DeformableFeatureAggregation, self).__init__()
if embed_dims % num_groups != 0:
raise ValueError(
f"embed_dims must be divisible by num_groups, "
f"but got {embed_dims} and {num_groups}"
)
self.group_dims = int(embed_dims / num_groups)
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_groups = num_groups
self.num_cams = num_cams
self.use_temporal_anchor_embed = use_temporal_anchor_embed
if use_deformable_func:
assert DAF is not None, "deformable_aggregation needs to be set up."
self.use_deformable_func = use_deformable_func
self.attn_drop = attn_drop
self.residual_mode = residual_mode
self.proj_drop = nn.Dropout(proj_drop)
kps_generator["embed_dims"] = embed_dims
self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS)
self.num_pts = self.kps_generator.num_pts
if temporal_fusion_module is not None:
if "embed_dims" not in temporal_fusion_module:
temporal_fusion_module["embed_dims"] = embed_dims
self.temp_module = build_from_cfg(
temporal_fusion_module, PLUGIN_LAYERS
)
else:
self.temp_module = None
self.output_proj = Linear(embed_dims, embed_dims)
if use_camera_embed:
self.camera_encoder = Sequential(
*linear_relu_ln(embed_dims, 1, 2, 12)
)
self.weights_fc = Linear(
embed_dims, num_groups * num_levels * self.num_pts
)
else:
self.camera_encoder = None
self.weights_fc = Linear(
embed_dims, num_groups * num_cams * num_levels * self.num_pts
)
def init_weight(self):
constant_init(self.weights_fc, val=0.0, bias=0.0)
xavier_init(self.output_proj, distribution="uniform", bias=0.0)
def forward(
self,
instance_feature: torch.Tensor,
anchor: torch.Tensor,
anchor_embed: torch.Tensor,
feature_maps: List[torch.Tensor],
metas: dict,
**kwargs: dict,
):
bs, num_anchor = instance_feature.shape[:2]
key_points = self.kps_generator(anchor, instance_feature)
weights = self._get_weights(instance_feature, anchor_embed, metas)
if self.use_deformable_func:
points_2d = (
self.project_points(
key_points,
metas["projection_mat"],
metas.get("image_wh"),
)
.permute(0, 2, 3, 1, 4)
.reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
)
weights = (
weights.permute(0, 1, 4, 2, 3, 5)
.contiguous()
.reshape(
bs,
num_anchor,
self.num_pts,
self.num_cams,
self.num_levels,
self.num_groups,
)
)
features = DAF(*feature_maps, points_2d, weights).reshape(
bs, num_anchor, self.embed_dims
)
else:
features = self.feature_sampling(
feature_maps,
key_points,
metas["projection_mat"],
metas.get("image_wh"),
)
features = self.multi_view_level_fusion(features, weights)
features = features.sum(dim=2) # fuse multi-point features
output = self.proj_drop(self.output_proj(features))
if self.residual_mode == "add":
output = output + instance_feature
elif self.residual_mode == "cat":
output = torch.cat([output, instance_feature], dim=-1)
return output
def _get_weights(self, instance_feature, anchor_embed, metas=None):
bs, num_anchor = instance_feature.shape[:2]
feature = instance_feature + anchor_embed
if self.camera_encoder is not None:
camera_embed = self.camera_encoder(
metas["projection_mat"][:, :, :3].reshape(
bs, self.num_cams, -1
)
)
feature = feature[:, :, None] + camera_embed[:, None]
weights = (
self.weights_fc(feature)
.reshape(bs, num_anchor, -1, self.num_groups)
.softmax(dim=-2)
.reshape(
bs,
num_anchor,
self.num_cams,
self.num_levels,
self.num_pts,
self.num_groups,
)
)
        if self.training and self.attn_drop > 0:
            # randomly drop per-camera/per-key-point attention weights during
            # training and rescale the survivors (inverted-dropout style)
            mask = torch.rand(
                bs, num_anchor, self.num_cams, 1, self.num_pts, 1
            )
            mask = mask.to(device=weights.device, dtype=weights.dtype)
            weights = ((mask > self.attn_drop) * weights) / (
                1 - self.attn_drop
            )
return weights
@staticmethod
def project_points(key_points, projection_mat, image_wh=None):
bs, num_anchor, num_pts = key_points.shape[:3]
pts_extend = torch.cat(
[key_points, torch.ones_like(key_points[..., :1])], dim=-1
)
points_2d = torch.matmul(
projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
).squeeze(-1)
points_2d = points_2d[..., :2] / torch.clamp(
points_2d[..., 2:3], min=1e-5
)
if image_wh is not None:
points_2d = points_2d / image_wh[:, :, None, None]
return points_2d
@staticmethod
def feature_sampling(
feature_maps: List[torch.Tensor],
key_points: torch.Tensor,
projection_mat: torch.Tensor,
image_wh: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_levels = len(feature_maps)
num_cams = feature_maps[0].shape[1]
bs, num_anchor, num_pts = key_points.shape[:3]
points_2d = DeformableFeatureAggregation.project_points(
key_points, projection_mat, image_wh
)
points_2d = points_2d * 2 - 1
points_2d = points_2d.flatten(end_dim=1)
features = []
for fm in feature_maps:
features.append(
torch.nn.functional.grid_sample(
fm.flatten(end_dim=1), points_2d
)
)
features = torch.stack(features, dim=1)
features = features.reshape(
bs, num_cams, num_levels, -1, num_anchor, num_pts
).permute(
0, 4, 1, 2, 5, 3
) # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims
return features
def multi_view_level_fusion(
self,
features: torch.Tensor,
weights: torch.Tensor,
):
bs, num_anchor = weights.shape[:2]
features = weights[..., None] * features.reshape(
features.shape[:-1] + (self.num_groups, self.group_dims)
)
features = features.sum(dim=2).sum(dim=2)
features = features.reshape(
bs, num_anchor, self.num_pts, self.embed_dims
)
return features
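# A minimal config sketch for the module above (values are illustrative
# assumptions, not taken from this commit):
#
#   deformable_model = dict(
#       type="DeformableFeatureAggregation",
#       embed_dims=256,
#       num_groups=8,
#       num_levels=4,
#       num_cams=6,
#       kps_generator=dict(type="SparseBox3DKeyPointsGenerator"),
#   )
#
# forward() samples multi-camera, multi-level features at the generated key
# points and fuses them with the predicted group-wise softmax weights.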
@PLUGIN_LAYERS.register_module()
class DenseDepthNet(BaseModule):
def __init__(
self,
embed_dims=256,
num_depth_layers=1,
equal_focal=100,
max_depth=60,
loss_weight=1.0,
):
super().__init__()
self.embed_dims = embed_dims
self.equal_focal = equal_focal
self.num_depth_layers = num_depth_layers
self.max_depth = max_depth
self.loss_weight = loss_weight
self.depth_layers = nn.ModuleList()
for i in range(num_depth_layers):
self.depth_layers.append(
nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0)
)
def forward(self, feature_maps, focal=None, gt_depths=None):
if focal is None:
focal = self.equal_focal
else:
focal = focal.reshape(-1)
depths = []
for i, feat in enumerate(feature_maps[: self.num_depth_layers]):
depth = self.depth_layers[i](feat.flatten(end_dim=1).float()).exp()
depth = depth.transpose(0, -1) * focal / self.equal_focal
depth = depth.transpose(0, -1)
depths.append(depth)
if gt_depths is not None and self.training:
loss = self.loss(depths, gt_depths)
return loss
return depths
def loss(self, depth_preds, gt_depths):
loss = 0.0
for pred, gt in zip(depth_preds, gt_depths):
pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1)
gt = gt.reshape(-1)
fg_mask = torch.logical_and(
gt > 0.0, torch.logical_not(torch.isnan(pred))
)
gt = gt[fg_mask]
pred = pred[fg_mask]
pred = torch.clip(pred, 0.0, self.max_depth)
with autocast(enabled=False):
error = torch.abs(pred - gt).sum()
_loss = (
error
/ max(1.0, len(gt) * len(depth_preds))
* self.loss_weight
)
loss = loss + _loss
return loss
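    # In words: per supervised level, the loss is sum(|pred - gt|) over valid
    # foreground pixels (gt > 0, pred not NaN, pred clipped to
    # [0, max_depth]), normalized by max(1, num_fg * num_levels), scaled by
    # loss_weight, and accumulated across levels.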
@FEEDFORWARD_NETWORK.register_module()
class AsymmetricFFN(BaseModule):
def __init__(
self,
in_channels=None,
pre_norm=None,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type="ReLU", inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True,
init_cfg=None,
**kwargs,
):
super(AsymmetricFFN, self).__init__(init_cfg)
        assert num_fcs >= 2, (
            f"num_fcs should be no less than 2, but got {num_fcs}."
        )
self.in_channels = in_channels
self.pre_norm = pre_norm
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = []
if in_channels is None:
in_channels = embed_dims
if pre_norm is not None:
self.pre_norm = build_norm_layer(pre_norm, in_channels)[1]
for _ in range(num_fcs - 1):
layers.append(
Sequential(
Linear(in_channels, feedforward_channels),
self.activate,
nn.Dropout(ffn_drop),
)
)
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = (
build_dropout(dropout_layer)
if dropout_layer
else torch.nn.Identity()
)
        self.add_identity = add_identity
        if self.add_identity:
            # Compare against the original input dims; the local
            # `in_channels` was reassigned to feedforward_channels above.
            self.identity_fc = (
                torch.nn.Identity()
                if self.in_channels in (None, embed_dims)
                else Linear(self.in_channels, embed_dims)
            )
def forward(self, x, identity=None):
if self.pre_norm is not None:
x = self.pre_norm(x)
out = self.layers(x)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = x
identity = self.identity_fc(identity)
return identity + self.dropout_layer(out)
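# A minimal config sketch for the FFN above (values are illustrative
# assumptions, not from this commit):
#   ffn = dict(
#       type="AsymmetricFFN",
#       in_channels=512,            # e.g. embed_dims * 2 after decoupled attn
#       pre_norm=dict(type="LN"),
#       embed_dims=256,
#       feedforward_channels=1024,
#   )
# "Asymmetric" refers to allowing in_channels != embed_dims; the identity
# path is linearly projected to embed_dims when the dims differ.

# ---------------------------------------------------------------------------
# New file: likely the detection3d package __init__.py, judging from the
# relative imports below.
# ---------------------------------------------------------------------------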
from .decoder import SparseBox3DDecoder
from .target import SparseBox3DTarget
from .detection3d_blocks import (
SparseBox3DRefinementModule,
SparseBox3DKeyPointsGenerator,
SparseBox3DEncoder,
)
from .losses import SparseBox3DLoss
from .detection3d_head import Sparse4DHead
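# ---------------------------------------------------------------------------
# New file: likely decoder.py, judging from the import
# `from .decoder import SparseBox3DDecoder` above.
# ---------------------------------------------------------------------------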
from typing import Optional
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
from projects.mmdet3d_plugin.core.box3d import *
def decode_box(box):
yaw = torch.atan2(box[..., SIN_YAW], box[..., COS_YAW])
box = torch.cat(
[
box[..., [X, Y, Z]],
box[..., [W, L, H]].exp(),
yaw[..., None],
box[..., VX:],
],
dim=-1,
)
return box
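# decode_box assumes the box-state layout from core.box3d, i.e.
# [X, Y, Z, log W, log L, log H, SIN_YAW, COS_YAW, VX, ...]: it restores
# metric sizes with exp() and collapses (sin, cos) back into a yaw angle.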
@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
def __init__(
self,
num_output: int = 300,
score_threshold: Optional[float] = None,
sorted: bool = True,
):
super(SparseBox3DDecoder, self).__init__()
self.num_output = num_output
self.score_threshold = score_threshold
self.sorted = sorted
def decode(
self,
cls_scores,
box_preds,
instance_id=None,
quality=None,
output_idx=-1,
):
squeeze_cls = instance_id is not None
cls_scores = cls_scores[output_idx].sigmoid()
if squeeze_cls:
cls_scores, cls_ids = cls_scores.max(dim=-1)
cls_scores = cls_scores.unsqueeze(dim=-1)
box_preds = box_preds[output_idx]
bs, num_pred, num_cls = cls_scores.shape
cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
self.num_output, dim=1, sorted=self.sorted
)
if not squeeze_cls:
cls_ids = indices % num_cls
if self.score_threshold is not None:
mask = cls_scores >= self.score_threshold
        if quality is None or quality[output_idx] is None:
            quality = None
if quality is not None:
centerness = quality[output_idx][..., CNS]
centerness = torch.gather(centerness, 1, indices // num_cls)
cls_scores_origin = cls_scores.clone()
cls_scores *= centerness.sigmoid()
cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
if not squeeze_cls:
cls_ids = torch.gather(cls_ids, 1, idx)
if self.score_threshold is not None:
mask = torch.gather(mask, 1, idx)
indices = torch.gather(indices, 1, idx)
output = []
for i in range(bs):
category_ids = cls_ids[i]
if squeeze_cls:
category_ids = category_ids[indices[i]]
scores = cls_scores[i]
box = box_preds[i, indices[i] // num_cls]
if self.score_threshold is not None:
category_ids = category_ids[mask[i]]
scores = scores[mask[i]]
box = box[mask[i]]
if quality is not None:
scores_origin = cls_scores_origin[i]
if self.score_threshold is not None:
scores_origin = scores_origin[mask[i]]
box = decode_box(box)
output.append(
{
"boxes_3d": box.cpu(),
"scores_3d": scores.cpu(),
"labels_3d": category_ids.cpu(),
}
)
if quality is not None:
output[-1]["cls_scores"] = scores_origin.cpu()
if instance_id is not None:
ids = instance_id[i, indices[i]]
if self.score_threshold is not None:
ids = ids[mask[i]]
output[-1]["instance_ids"] = ids
return output
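# ---------------------------------------------------------------------------
# New file: likely detection3d_blocks.py, judging from the import
# `from .detection3d_blocks import ...` above.
# ---------------------------------------------------------------------------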
import torch
import torch.nn as nn
import numpy as np
from mmcv.cnn import Linear, Scale, bias_init_with_prob
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.registry import (
PLUGIN_LAYERS,
POSITIONAL_ENCODING,
)
from projects.mmdet3d_plugin.core.box3d import *
from ..blocks import linear_relu_ln
__all__ = [
"SparseBox3DRefinementModule",
"SparseBox3DKeyPointsGenerator",
"SparseBox3DEncoder",
]
@POSITIONAL_ENCODING.register_module()
class SparseBox3DEncoder(BaseModule):
def __init__(
self,
embed_dims,
vel_dims=3,
mode="add",
output_fc=True,
in_loops=1,
out_loops=2,
):
super().__init__()
assert mode in ["add", "cat"]
self.embed_dims = embed_dims
self.vel_dims = vel_dims
self.mode = mode
def embedding_layer(input_dims, output_dims):
return nn.Sequential(
*linear_relu_ln(output_dims, in_loops, out_loops, input_dims)
)
if not isinstance(embed_dims, (list, tuple)):
embed_dims = [embed_dims] * 5
self.pos_fc = embedding_layer(3, embed_dims[0])
self.size_fc = embedding_layer(3, embed_dims[1])
self.yaw_fc = embedding_layer(2, embed_dims[2])
if vel_dims > 0:
self.vel_fc = embedding_layer(self.vel_dims, embed_dims[3])
if output_fc:
self.output_fc = embedding_layer(embed_dims[-1], embed_dims[-1])
else:
self.output_fc = None
def forward(self, box_3d: torch.Tensor):
pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]])
size_feat = self.size_fc(box_3d[..., [W, L, H]])
yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]])
if self.mode == "add":
output = pos_feat + size_feat + yaw_feat
elif self.mode == "cat":
output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1)
if self.vel_dims > 0:
vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims])
if self.mode == "add":
output = output + vel_feat
elif self.mode == "cat":
output = torch.cat([output, vel_feat], dim=-1)
if self.output_fc is not None:
output = self.output_fc(output)
return output
@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
def __init__(
self,
embed_dims=256,
output_dim=11,
num_cls=10,
normalize_yaw=False,
refine_yaw=False,
with_cls_branch=True,
with_quality_estimation=False,
):
super(SparseBox3DRefinementModule, self).__init__()
self.embed_dims = embed_dims
self.output_dim = output_dim
self.num_cls = num_cls
self.normalize_yaw = normalize_yaw
self.refine_yaw = refine_yaw
self.refine_state = [X, Y, Z, W, L, H]
if self.refine_yaw:
self.refine_state += [SIN_YAW, COS_YAW]
self.layers = nn.Sequential(
*linear_relu_ln(embed_dims, 2, 2),
Linear(self.embed_dims, self.output_dim),
Scale([1.0] * self.output_dim),
)
self.with_cls_branch = with_cls_branch
if with_cls_branch:
self.cls_layers = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(self.embed_dims, self.num_cls),
)
self.with_quality_estimation = with_quality_estimation
if with_quality_estimation:
self.quality_layers = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(self.embed_dims, 2),
)
def init_weight(self):
if self.with_cls_branch:
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.cls_layers[-1].bias, bias_init)
def forward(
self,
instance_feature: torch.Tensor,
anchor: torch.Tensor,
anchor_embed: torch.Tensor,
time_interval: torch.Tensor = 1.0,
return_cls=True,
):
feature = instance_feature + anchor_embed
output = self.layers(feature)
output[..., self.refine_state] = (
output[..., self.refine_state] + anchor[..., self.refine_state]
)
if self.normalize_yaw:
output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(
output[..., [SIN_YAW, COS_YAW]], dim=-1
)
if self.output_dim > 8:
if not isinstance(time_interval, torch.Tensor):
time_interval = instance_feature.new_tensor(time_interval)
translation = torch.transpose(output[..., VX:], 0, -1)
velocity = torch.transpose(translation / time_interval, 0, -1)
output[..., VX:] = velocity + anchor[..., VX:]
if return_cls:
            assert self.with_cls_branch, "Classification branch is disabled (with_cls_branch=False)."
cls = self.cls_layers(instance_feature)
else:
cls = None
if return_cls and self.with_quality_estimation:
quality = self.quality_layers(feature)
else:
quality = None
return output, cls, quality
@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
def __init__(
self,
embed_dims=256,
num_learnable_pts=0,
fix_scale=None,
):
super(SparseBox3DKeyPointsGenerator, self).__init__()
self.embed_dims = embed_dims
self.num_learnable_pts = num_learnable_pts
if fix_scale is None:
fix_scale = ((0.0, 0.0, 0.0),)
self.fix_scale = nn.Parameter(
torch.tensor(fix_scale), requires_grad=False
)
self.num_pts = len(self.fix_scale) + num_learnable_pts
if num_learnable_pts > 0:
self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)
def init_weight(self):
if self.num_learnable_pts > 0:
xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)
def forward(
self,
anchor,
instance_feature=None,
T_cur2temp_list=None,
cur_timestamp=None,
temp_timestamps=None,
):
bs, num_anchor = anchor.shape[:2]
size = anchor[..., None, [W, L, H]].exp()
key_points = self.fix_scale * size
if self.num_learnable_pts > 0 and instance_feature is not None:
learnable_scale = (
self.learnable_fc(instance_feature)
.reshape(bs, num_anchor, self.num_learnable_pts, 3)
.sigmoid()
- 0.5
)
key_points = torch.cat(
[key_points, learnable_scale * size], dim=-2
)
rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])
rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
rotation_mat[:, :, 2, 2] = 1
key_points = torch.matmul(
rotation_mat[:, :, None], key_points[..., None]
).squeeze(-1)
key_points = key_points + anchor[..., None, [X, Y, Z]]
if (
cur_timestamp is None
or temp_timestamps is None
or T_cur2temp_list is None
or len(temp_timestamps) == 0
):
return key_points
temp_key_points_list = []
velocity = anchor[..., VX:]
for i, t_time in enumerate(temp_timestamps):
time_interval = cur_timestamp - t_time
translation = (
velocity
* time_interval.to(dtype=velocity.dtype)[:, None, None]
)
temp_key_points = key_points - translation[:, :, None]
T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
temp_key_points = (
T_cur2temp[:, None, None, :3]
@ torch.cat(
[
temp_key_points,
torch.ones_like(temp_key_points[..., :1]),
],
dim=-1,
).unsqueeze(-1)
)
temp_key_points = temp_key_points.squeeze(-1)
temp_key_points_list.append(temp_key_points)
return key_points, temp_key_points_list
@staticmethod
def anchor_projection(
anchor,
T_src2dst_list,
src_timestamp=None,
dst_timestamps=None,
time_intervals=None,
):
dst_anchors = []
for i in range(len(T_src2dst_list)):
vel = anchor[..., VX:]
vel_dim = vel.shape[-1]
T_src2dst = torch.unsqueeze(
T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
)
center = anchor[..., [X, Y, Z]]
if time_intervals is not None:
time_interval = time_intervals[i]
elif src_timestamp is not None and dst_timestamps is not None:
time_interval = (src_timestamp - dst_timestamps[i]).to(
dtype=vel.dtype
)
else:
time_interval = None
if time_interval is not None:
translation = vel.transpose(0, -1) * time_interval
translation = translation.transpose(0, -1)
center = center - translation
center = (
torch.matmul(
T_src2dst[..., :3, :3], center[..., None]
).squeeze(dim=-1)
+ T_src2dst[..., :3, 3]
)
size = anchor[..., [W, L, H]]
yaw = torch.matmul(
T_src2dst[..., :2, :2],
anchor[..., [COS_YAW, SIN_YAW], None],
).squeeze(-1)
            yaw = yaw[..., [1, 0]]
vel = torch.matmul(
T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
).squeeze(-1)
dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
dst_anchors.append(dst_anchor)
return dst_anchors
@staticmethod
def distance(anchor):
return torch.norm(anchor[..., :2], p=2, dim=-1)
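# ---------------------------------------------------------------------------
# New file: likely detection3d_head.py, judging from the import
# `from .detection3d_head import Sparse4DHead` above.
# ---------------------------------------------------------------------------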
from typing import List, Optional, Tuple, Union
import warnings
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.registry import (
ATTENTION,
PLUGIN_LAYERS,
POSITIONAL_ENCODING,
FEEDFORWARD_NETWORK,
NORM_LAYERS,
)
from mmcv.runner import BaseModule, force_fp32
from mmcv.utils import build_from_cfg
from mmdet.core.bbox.builder import BBOX_SAMPLERS
from mmdet.core.bbox.builder import BBOX_CODERS
from mmdet.models import HEADS, LOSSES
from mmdet.core import reduce_mean
from ..blocks import DeformableFeatureAggregation as DFG
__all__ = ["Sparse4DHead"]
@HEADS.register_module()
class Sparse4DHead(BaseModule):
def __init__(
self,
instance_bank: dict,
anchor_encoder: dict,
graph_model: dict,
norm_layer: dict,
ffn: dict,
deformable_model: dict,
refine_layer: dict,
num_decoder: int = 6,
num_single_frame_decoder: int = -1,
temp_graph_model: dict = None,
loss_cls: dict = None,
loss_reg: dict = None,
decoder: dict = None,
sampler: dict = None,
gt_cls_key: str = "gt_labels_3d",
gt_reg_key: str = "gt_bboxes_3d",
gt_id_key: str = "instance_id",
with_instance_id: bool = True,
task_prefix: str = 'det',
reg_weights: List = None,
operation_order: Optional[List[str]] = None,
cls_threshold_to_reg: float = -1,
dn_loss_weight: float = 5.0,
decouple_attn: bool = True,
init_cfg: dict = None,
**kwargs,
):
super(Sparse4DHead, self).__init__(init_cfg)
self.num_decoder = num_decoder
self.num_single_frame_decoder = num_single_frame_decoder
self.gt_cls_key = gt_cls_key
self.gt_reg_key = gt_reg_key
self.gt_id_key = gt_id_key
self.with_instance_id = with_instance_id
self.task_prefix = task_prefix
self.cls_threshold_to_reg = cls_threshold_to_reg
self.dn_loss_weight = dn_loss_weight
self.decouple_attn = decouple_attn
if reg_weights is None:
self.reg_weights = [1.0] * 10
else:
self.reg_weights = reg_weights
if operation_order is None:
operation_order = [
"temp_gnn",
"gnn",
"norm",
"deformable",
"norm",
"ffn",
"norm",
"refine",
] * num_decoder
            # delete the 'gnn' and 'norm' layers in the first transformer block
operation_order = operation_order[3:]
self.operation_order = operation_order
# =========== build modules ===========
def build(cfg, registry):
if cfg is None:
return None
return build_from_cfg(cfg, registry)
self.instance_bank = build(instance_bank, PLUGIN_LAYERS)
self.anchor_encoder = build(anchor_encoder, POSITIONAL_ENCODING)
self.sampler = build(sampler, BBOX_SAMPLERS)
self.decoder = build(decoder, BBOX_CODERS)
self.loss_cls = build(loss_cls, LOSSES)
self.loss_reg = build(loss_reg, LOSSES)
self.op_config_map = {
"temp_gnn": [temp_graph_model, ATTENTION],
"gnn": [graph_model, ATTENTION],
"norm": [norm_layer, NORM_LAYERS],
"ffn": [ffn, FEEDFORWARD_NETWORK],
"deformable": [deformable_model, ATTENTION],
"refine": [refine_layer, PLUGIN_LAYERS],
}
self.layers = nn.ModuleList(
[
build(*self.op_config_map.get(op, [None, None]))
for op in self.operation_order
]
)
self.embed_dims = self.instance_bank.embed_dims
if self.decouple_attn:
self.fc_before = nn.Linear(
self.embed_dims, self.embed_dims * 2, bias=False
)
self.fc_after = nn.Linear(
self.embed_dims * 2, self.embed_dims, bias=False
)
else:
self.fc_before = nn.Identity()
self.fc_after = nn.Identity()
def init_weights(self):
for i, op in enumerate(self.operation_order):
if self.layers[i] is None:
continue
elif op != "refine":
for p in self.layers[i].parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if hasattr(m, "init_weight"):
m.init_weight()
def graph_model(
self,
index,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
**kwargs,
):
if self.decouple_attn:
query = torch.cat([query, query_pos], dim=-1)
if key is not None:
key = torch.cat([key, key_pos], dim=-1)
query_pos, key_pos = None, None
if value is not None:
value = self.fc_before(value)
return self.fc_after(
self.layers[index](
query,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
**kwargs,
)
)
def forward(
self,
feature_maps: Union[torch.Tensor, List],
metas: dict,
):
if isinstance(feature_maps, torch.Tensor):
feature_maps = [feature_maps]
batch_size = feature_maps[0].shape[0]
# ========= get instance info ============
if (
self.sampler.dn_metas is not None
and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
):
self.sampler.dn_metas = None
(
instance_feature,
anchor,
temp_instance_feature,
temp_anchor,
time_interval,
) = self.instance_bank.get(
batch_size, metas, dn_metas=self.sampler.dn_metas
)
        # ========= prepare for denoising training ============
# 1. get dn metas: noisy-anchors and corresponding GT
# 2. concat learnable instances and noisy instances
# 3. get attention mask
attn_mask = None
dn_metas = None
temp_dn_reg_target = None
if self.training and hasattr(self.sampler, "get_dn_anchors"):
if self.gt_id_key in metas["img_metas"][0]:
gt_instance_id = [
torch.from_numpy(x[self.gt_id_key]).cuda()
for x in metas["img_metas"]
]
else:
gt_instance_id = None
dn_metas = self.sampler.get_dn_anchors(
metas[self.gt_cls_key],
metas[self.gt_reg_key],
gt_instance_id,
)
if dn_metas is not None:
(
dn_anchor,
dn_reg_target,
dn_cls_target,
dn_attn_mask,
valid_mask,
dn_id_target,
) = dn_metas
num_dn_anchor = dn_anchor.shape[1]
if dn_anchor.shape[-1] != anchor.shape[-1]:
remain_state_dims = anchor.shape[-1] - dn_anchor.shape[-1]
dn_anchor = torch.cat(
[
dn_anchor,
dn_anchor.new_zeros(
batch_size, num_dn_anchor, remain_state_dims
),
],
dim=-1,
)
anchor = torch.cat([anchor, dn_anchor], dim=1)
instance_feature = torch.cat(
[
instance_feature,
instance_feature.new_zeros(
batch_size, num_dn_anchor, instance_feature.shape[-1]
),
],
dim=1,
)
num_instance = instance_feature.shape[1]
num_free_instance = num_instance - num_dn_anchor
attn_mask = anchor.new_ones(
(num_instance, num_instance), dtype=torch.bool
)
attn_mask[:num_free_instance, :num_free_instance] = False
attn_mask[num_free_instance:, num_free_instance:] = dn_attn_mask
anchor_embed = self.anchor_encoder(anchor)
if temp_anchor is not None:
temp_anchor_embed = self.anchor_encoder(temp_anchor)
else:
temp_anchor_embed = None
# =================== forward the layers ====================
prediction = []
classification = []
quality = []
for i, op in enumerate(self.operation_order):
if self.layers[i] is None:
continue
elif op == "temp_gnn":
instance_feature = self.graph_model(
i,
instance_feature,
temp_instance_feature,
temp_instance_feature,
query_pos=anchor_embed,
key_pos=temp_anchor_embed,
attn_mask=attn_mask
if temp_instance_feature is None
else None,
)
elif op == "gnn":
instance_feature = self.graph_model(
i,
instance_feature,
value=instance_feature,
query_pos=anchor_embed,
attn_mask=attn_mask,
)
elif op == "norm" or op == "ffn":
instance_feature = self.layers[i](instance_feature)
elif op == "deformable":
instance_feature = self.layers[i](
instance_feature,
anchor,
anchor_embed,
feature_maps,
metas,
)
elif op == "refine":
anchor, cls, qt = self.layers[i](
instance_feature,
anchor,
anchor_embed,
time_interval=time_interval,
return_cls=True,
)
prediction.append(anchor)
classification.append(cls)
quality.append(qt)
if len(prediction) == self.num_single_frame_decoder:
instance_feature, anchor = self.instance_bank.update(
instance_feature, anchor, cls
)
if (
dn_metas is not None
and self.sampler.num_temp_dn_groups > 0
and dn_id_target is not None
):
(
instance_feature,
anchor,
temp_dn_reg_target,
temp_dn_cls_target,
temp_valid_mask,
dn_id_target,
) = self.sampler.update_dn(
instance_feature,
anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id_target,
self.instance_bank.num_anchor,
self.instance_bank.mask,
)
anchor_embed = self.anchor_encoder(anchor)
if (
len(prediction) > self.num_single_frame_decoder
and temp_anchor_embed is not None
):
temp_anchor_embed = anchor_embed[
:, : self.instance_bank.num_temp_instances
]
else:
raise NotImplementedError(f"{op} is not supported.")
output = {}
# split predictions of learnable instances and noisy instances
if dn_metas is not None:
dn_classification = [
x[:, num_free_instance:] for x in classification
]
classification = [x[:, :num_free_instance] for x in classification]
dn_prediction = [x[:, num_free_instance:] for x in prediction]
prediction = [x[:, :num_free_instance] for x in prediction]
quality = [
x[:, :num_free_instance] if x is not None else None
for x in quality
]
output.update(
{
"dn_prediction": dn_prediction,
"dn_classification": dn_classification,
"dn_reg_target": dn_reg_target,
"dn_cls_target": dn_cls_target,
"dn_valid_mask": valid_mask,
}
)
if temp_dn_reg_target is not None:
output.update(
{
"temp_dn_reg_target": temp_dn_reg_target,
"temp_dn_cls_target": temp_dn_cls_target,
"temp_dn_valid_mask": temp_valid_mask,
"dn_id_target": dn_id_target,
}
)
dn_cls_target = temp_dn_cls_target
valid_mask = temp_valid_mask
dn_instance_feature = instance_feature[:, num_free_instance:]
dn_anchor = anchor[:, num_free_instance:]
instance_feature = instance_feature[:, :num_free_instance]
anchor_embed = anchor_embed[:, :num_free_instance]
anchor = anchor[:, :num_free_instance]
cls = cls[:, :num_free_instance]
# cache dn_metas for temporal denoising
self.sampler.cache_dn(
dn_instance_feature,
dn_anchor,
dn_cls_target,
valid_mask,
dn_id_target,
)
output.update(
{
"classification": classification,
"prediction": prediction,
"quality": quality,
"instance_feature": instance_feature,
"anchor_embed": anchor_embed,
}
)
# cache current instances for temporal modeling
self.instance_bank.cache(
instance_feature, anchor, cls, metas, feature_maps
)
if self.with_instance_id:
instance_id = self.instance_bank.get_instance_id(
cls, anchor, self.decoder.score_threshold
)
output["instance_id"] = instance_id
return output
@force_fp32(apply_to=("model_outs"))
def loss(self, model_outs, data, feature_maps=None):
# ===================== prediction losses ======================
cls_scores = model_outs["classification"]
reg_preds = model_outs["prediction"]
quality = model_outs["quality"]
output = {}
for decoder_idx, (cls, reg, qt) in enumerate(
zip(cls_scores, reg_preds, quality)
):
reg = reg[..., : len(self.reg_weights)]
cls_target, reg_target, reg_weights = self.sampler.sample(
cls,
reg,
data[self.gt_cls_key],
data[self.gt_reg_key],
)
reg_target = reg_target[..., : len(self.reg_weights)]
reg_target_full = reg_target.clone()
mask = torch.logical_not(torch.all(reg_target == 0, dim=-1))
mask_valid = mask.clone()
num_pos = max(
reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0
)
if self.cls_threshold_to_reg > 0:
threshold = self.cls_threshold_to_reg
mask = torch.logical_and(
mask, cls.max(dim=-1).values.sigmoid() > threshold
)
cls = cls.flatten(end_dim=1)
cls_target = cls_target.flatten(end_dim=1)
cls_loss = self.loss_cls(cls, cls_target, avg_factor=num_pos)
mask = mask.reshape(-1)
reg_weights = reg_weights * reg.new_tensor(self.reg_weights)
reg_target = reg_target.flatten(end_dim=1)[mask]
reg = reg.flatten(end_dim=1)[mask]
reg_weights = reg_weights.flatten(end_dim=1)[mask]
reg_target = torch.where(
reg_target.isnan(), reg.new_tensor(0.0), reg_target
)
cls_target = cls_target[mask]
if qt is not None:
qt = qt.flatten(end_dim=1)[mask]
reg_loss = self.loss_reg(
reg,
reg_target,
weight=reg_weights,
avg_factor=num_pos,
prefix=f"{self.task_prefix}_",
suffix=f"_{decoder_idx}",
quality=qt,
cls_target=cls_target,
)
output[f"{self.task_prefix}_loss_cls_{decoder_idx}"] = cls_loss
output.update(reg_loss)
if "dn_prediction" not in model_outs:
return output
# ===================== denoising losses ======================
dn_cls_scores = model_outs["dn_classification"]
dn_reg_preds = model_outs["dn_prediction"]
(
dn_valid_mask,
dn_cls_target,
dn_reg_target,
dn_pos_mask,
reg_weights,
num_dn_pos,
) = self.prepare_for_dn_loss(model_outs)
for decoder_idx, (cls, reg) in enumerate(
zip(dn_cls_scores, dn_reg_preds)
):
if (
"temp_dn_valid_mask" in model_outs
and decoder_idx == self.num_single_frame_decoder
):
(
dn_valid_mask,
dn_cls_target,
dn_reg_target,
dn_pos_mask,
reg_weights,
num_dn_pos,
) = self.prepare_for_dn_loss(model_outs, prefix="temp_")
cls_loss = self.loss_cls(
cls.flatten(end_dim=1)[dn_valid_mask],
dn_cls_target,
avg_factor=num_dn_pos,
)
reg_loss = self.loss_reg(
reg.flatten(end_dim=1)[dn_valid_mask][dn_pos_mask][
..., : len(self.reg_weights)
],
dn_reg_target,
avg_factor=num_dn_pos,
weight=reg_weights,
prefix=f"{self.task_prefix}_",
suffix=f"_dn_{decoder_idx}",
)
output[f"{self.task_prefix}_loss_cls_dn_{decoder_idx}"] = cls_loss
output.update(reg_loss)
return output
def prepare_for_dn_loss(self, model_outs, prefix=""):
dn_valid_mask = model_outs[f"{prefix}dn_valid_mask"].flatten(end_dim=1)
dn_cls_target = model_outs[f"{prefix}dn_cls_target"].flatten(
end_dim=1
)[dn_valid_mask]
dn_reg_target = model_outs[f"{prefix}dn_reg_target"].flatten(
end_dim=1
)[dn_valid_mask][..., : len(self.reg_weights)]
dn_pos_mask = dn_cls_target >= 0
dn_reg_target = dn_reg_target[dn_pos_mask]
reg_weights = dn_reg_target.new_tensor(self.reg_weights)[None].tile(
dn_reg_target.shape[0], 1
)
num_dn_pos = max(
reduce_mean(torch.sum(dn_valid_mask).to(dtype=reg_weights.dtype)),
1.0,
)
return (
dn_valid_mask,
dn_cls_target,
dn_reg_target,
dn_pos_mask,
reg_weights,
num_dn_pos,
)
@force_fp32(apply_to=("model_outs"))
def post_process(self, model_outs, output_idx=-1):
return self.decoder.decode(
model_outs["classification"],
model_outs["prediction"],
model_outs.get("instance_id"),
model_outs.get("quality"),
output_idx=output_idx,
)
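# ---------------------------------------------------------------------------
# New file: likely losses.py, judging from the import
# `from .losses import SparseBox3DLoss` above.
# ---------------------------------------------------------------------------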
import torch
import torch.nn as nn
from mmcv.utils import build_from_cfg
from mmdet.models.builder import LOSSES
from projects.mmdet3d_plugin.core.box3d import *
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
def __init__(
self,
loss_box,
loss_centerness=None,
loss_yawness=None,
cls_allow_reverse=None,
):
super().__init__()
def build(cfg, registry):
if cfg is None:
return None
return build_from_cfg(cfg, registry)
self.loss_box = build(loss_box, LOSSES)
self.loss_cns = build(loss_centerness, LOSSES)
self.loss_yns = build(loss_yawness, LOSSES)
self.cls_allow_reverse = cls_allow_reverse
def forward(
self,
box,
box_target,
weight=None,
avg_factor=None,
prefix="",
suffix="",
quality=None,
cls_target=None,
**kwargs,
):
        # Some categories do not distinguish between positive and negative
        # directions, e.g. the barrier class in the nuScenes dataset.
if self.cls_allow_reverse is not None and cls_target is not None:
if_reverse = (
torch.nn.functional.cosine_similarity(
box_target[..., [SIN_YAW, COS_YAW]],
box[..., [SIN_YAW, COS_YAW]],
dim=-1,
)
< 0
)
if_reverse = (
torch.isin(
cls_target, cls_target.new_tensor(self.cls_allow_reverse)
)
& if_reverse
)
box_target[..., [SIN_YAW, COS_YAW]] = torch.where(
if_reverse[..., None],
-box_target[..., [SIN_YAW, COS_YAW]],
box_target[..., [SIN_YAW, COS_YAW]],
)
output = {}
box_loss = self.loss_box(
box, box_target, weight=weight, avg_factor=avg_factor
)
output[f"{prefix}loss_box{suffix}"] = box_loss
if quality is not None:
cns = quality[..., CNS]
yns = quality[..., YNS].sigmoid()
cns_target = torch.norm(
box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1
)
cns_target = torch.exp(-cns_target)
cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor)
output[f"{prefix}loss_cns{suffix}"] = cns_loss
yns_target = (
torch.nn.functional.cosine_similarity(
box_target[..., [SIN_YAW, COS_YAW]],
box[..., [SIN_YAW, COS_YAW]],
dim=-1,
)
> 0
)
yns_target = yns_target.float()
yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor)
output[f"{prefix}loss_yns{suffix}"] = yns_loss
return output
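# ---------------------------------------------------------------------------
# New file: likely target.py, judging from the import
# `from .target import SparseBox3DTarget` above.
# ---------------------------------------------------------------------------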
import torch
import numpy as np
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from mmdet.core.bbox.builder import BBOX_SAMPLERS
from projects.mmdet3d_plugin.core.box3d import *
from ..base_target import BaseTargetWithDenoising
__all__ = ["SparseBox3DTarget"]
@BBOX_SAMPLERS.register_module()
class SparseBox3DTarget(BaseTargetWithDenoising):
def __init__(
self,
cls_weight=2.0,
alpha=0.25,
gamma=2,
eps=1e-12,
box_weight=0.25,
reg_weights=None,
cls_wise_reg_weights=None,
num_dn_groups=0,
dn_noise_scale=0.5,
max_dn_gt=32,
add_neg_dn=True,
num_temp_dn_groups=0,
):
super(SparseBox3DTarget, self).__init__(
num_dn_groups, num_temp_dn_groups
)
self.cls_weight = cls_weight
self.box_weight = box_weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
self.reg_weights = reg_weights
if self.reg_weights is None:
self.reg_weights = [1.0] * 8 + [0.0] * 2
self.cls_wise_reg_weights = cls_wise_reg_weights
self.dn_noise_scale = dn_noise_scale
self.max_dn_gt = max_dn_gt
self.add_neg_dn = add_neg_dn
def encode_reg_target(self, box_target, device=None):
outputs = []
for box in box_target:
output = torch.cat(
[
box[..., [X, Y, Z]],
box[..., [W, L, H]].log(),
torch.sin(box[..., YAW]).unsqueeze(-1),
torch.cos(box[..., YAW]).unsqueeze(-1),
box[..., YAW + 1 :],
],
dim=-1,
)
if device is not None:
output = output.to(device=device)
outputs.append(output)
return outputs
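    # Encoding sketch: a GT box [x, y, z, w, l, h, yaw, vx, vy] becomes
    # [x, y, z, log w, log l, log h, sin yaw, cos yaw, vx, vy], matching the
    # anchor state layout used throughout the head.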
def sample(
self,
cls_pred,
box_pred,
cls_target,
box_target,
):
bs, num_pred, num_cls = cls_pred.shape
cls_cost = self._cls_cost(cls_pred, cls_target)
box_target = self.encode_reg_target(box_target, box_pred.device)
instance_reg_weights = []
for i in range(len(box_target)):
weights = torch.logical_not(box_target[i].isnan()).to(
dtype=box_target[i].dtype
)
if self.cls_wise_reg_weights is not None:
for cls, weight in self.cls_wise_reg_weights.items():
weights = torch.where(
(cls_target[i] == cls)[:, None],
weights.new_tensor(weight),
weights,
)
instance_reg_weights.append(weights)
box_cost = self._box_cost(box_pred, box_target, instance_reg_weights)
indices = []
for i in range(bs):
if cls_cost[i] is not None and box_cost[i] is not None:
cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy()
cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
assign = linear_sum_assignment(cost)
indices.append(
[cls_pred.new_tensor(x, dtype=torch.int64) for x in assign]
)
else:
indices.append([None, None])
output_cls_target = (
cls_target[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls
)
output_box_target = box_pred.new_zeros(box_pred.shape)
output_reg_weights = box_pred.new_zeros(box_pred.shape)
for i, (pred_idx, target_idx) in enumerate(indices):
if len(cls_target[i]) == 0:
continue
output_cls_target[i, pred_idx] = cls_target[i][target_idx]
output_box_target[i, pred_idx] = box_target[i][target_idx]
output_reg_weights[i, pred_idx] = instance_reg_weights[i][
target_idx
]
self.indices = indices
return output_cls_target, output_box_target, output_reg_weights
def _cls_cost(self, cls_pred, cls_target):
bs = cls_pred.shape[0]
cls_pred = cls_pred.sigmoid()
cost = []
for i in range(bs):
if len(cls_target[i]) > 0:
neg_cost = (
-(1 - cls_pred[i] + self.eps).log()
* (1 - self.alpha)
* cls_pred[i].pow(self.gamma)
)
pos_cost = (
-(cls_pred[i] + self.eps).log()
* self.alpha
* (1 - cls_pred[i]).pow(self.gamma)
)
cost.append(
(pos_cost[:, cls_target[i]] - neg_cost[:, cls_target[i]])
* self.cls_weight
)
else:
cost.append(None)
return cost
def _box_cost(self, box_pred, box_target, instance_reg_weights):
bs = box_pred.shape[0]
cost = []
for i in range(bs):
if len(box_target[i]) > 0:
cost.append(
torch.sum(
torch.abs(box_pred[i, :, None] - box_target[i][None])
* instance_reg_weights[i][None]
* box_pred.new_tensor(self.reg_weights),
dim=-1,
)
* self.box_weight
)
else:
cost.append(None)
return cost
def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
if self.num_dn_groups <= 0:
return None
if self.num_temp_dn_groups <= 0:
gt_instance_id = None
if self.max_dn_gt > 0:
cls_target = [x[: self.max_dn_gt] for x in cls_target]
box_target = [x[: self.max_dn_gt] for x in box_target]
if gt_instance_id is not None:
gt_instance_id = [x[: self.max_dn_gt] for x in gt_instance_id]
max_dn_gt = max([len(x) for x in cls_target])
if max_dn_gt == 0:
return None
cls_target = torch.stack(
[
F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
for x in cls_target
]
)
box_target = self.encode_reg_target(box_target, cls_target.device)
box_target = torch.stack(
[F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0])) for x in box_target]
)
box_target = torch.where(
cls_target[..., None] == -1, box_target.new_tensor(0), box_target
)
if gt_instance_id is not None:
gt_instance_id = torch.stack(
[
F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
for x in gt_instance_id
]
)
bs, num_gt, state_dims = box_target.shape
if self.num_dn_groups > 1:
cls_target = cls_target.tile(self.num_dn_groups, 1)
box_target = box_target.tile(self.num_dn_groups, 1, 1)
if gt_instance_id is not None:
gt_instance_id = gt_instance_id.tile(self.num_dn_groups, 1)
noise = torch.rand_like(box_target) * 2 - 1
noise *= box_target.new_tensor(self.dn_noise_scale)
dn_anchor = box_target + noise
if self.add_neg_dn:
noise_neg = torch.rand_like(box_target) + 1
flag = torch.where(
torch.rand_like(box_target) > 0.5,
noise_neg.new_tensor(1),
noise_neg.new_tensor(-1),
)
noise_neg *= flag
noise_neg *= box_target.new_tensor(self.dn_noise_scale)
dn_anchor = torch.cat([dn_anchor, box_target + noise_neg], dim=1)
num_gt *= 2
box_cost = self._box_cost(
dn_anchor, box_target, torch.ones_like(box_target)
)
dn_box_target = torch.zeros_like(dn_anchor)
        # initialize to -3: anchors left at -3 are unmatched negatives (or
        # padding, which valid_mask filters out below)
        dn_cls_target = -torch.ones_like(cls_target) * 3
if gt_instance_id is not None:
dn_id_target = -torch.ones_like(gt_instance_id)
if self.add_neg_dn:
dn_cls_target = torch.cat([dn_cls_target, dn_cls_target], dim=1)
if gt_instance_id is not None:
dn_id_target = torch.cat([dn_id_target, dn_id_target], dim=1)
for i in range(dn_anchor.shape[0]):
cost = box_cost[i].cpu().numpy()
anchor_idx, gt_idx = linear_sum_assignment(cost)
anchor_idx = dn_anchor.new_tensor(anchor_idx, dtype=torch.int64)
gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
if gt_instance_id is not None:
dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]
dn_anchor = (
dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
.permute(1, 0, 2, 3)
.flatten(1, 2)
)
dn_box_target = (
dn_box_target.reshape(self.num_dn_groups, bs, num_gt, state_dims)
.permute(1, 0, 2, 3)
.flatten(1, 2)
)
dn_cls_target = (
dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
if gt_instance_id is not None:
dn_id_target = (
dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
else:
dn_id_target = None
valid_mask = dn_cls_target >= 0
if self.add_neg_dn:
cls_target = (
torch.cat([cls_target, cls_target], dim=1)
.reshape(self.num_dn_groups, bs, num_gt)
.permute(1, 0, 2)
.flatten(1)
)
            valid_mask = torch.logical_or(
                valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )  # "valid" marks items that do not come from padding.
attn_mask = dn_box_target.new_ones(
num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
)
for i in range(self.num_dn_groups):
start = num_gt * i
end = start + num_gt
attn_mask[start:end, start:end] = 0
attn_mask = attn_mask == 1
dn_cls_target = dn_cls_target.long()
return (
dn_anchor,
dn_box_target,
dn_cls_target,
attn_mask,
valid_mask,
dn_id_target,
)
def update_dn(
self,
instance_feature,
anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id_target,
        num_normal_anchor,
temporal_valid_mask,
):
bs, num_anchor = instance_feature.shape[:2]
if temporal_valid_mask is None:
self.dn_metas = None
        if self.dn_metas is None or num_normal_anchor >= num_anchor:
return (
instance_feature,
anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id_target,
)
# split instance_feature and anchor into non-dn and dn
        num_dn = num_anchor - num_normal_anchor
dn_instance_feature = instance_feature[:, -num_dn:]
dn_anchor = anchor[:, -num_dn:]
        instance_feature = instance_feature[:, :num_normal_anchor]
        anchor = anchor[:, :num_normal_anchor]
# reshape all dn metas from (bs,num_all_dn,xxx)
# to (bs, dn_group, num_dn_per_group, xxx)
num_dn_groups = self.num_dn_groups
num_dn = num_dn // num_dn_groups
dn_feat = dn_instance_feature.reshape(bs, num_dn_groups, num_dn, -1)
dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
dn_reg_target = dn_reg_target.reshape(bs, num_dn_groups, num_dn, -1)
dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
if dn_id_target is not None:
dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)
# update temp_dn_metas by instance_id
temp_dn_feat = self.dn_metas["dn_instance_feature"]
_, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
temp_dn_id = self.dn_metas["dn_id_target"]
# bs, num_temp_dn_groups, num_temp_dn, num_dn
match = temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
temp_reg_target = (
match[..., None] * dn_reg_target[:, :num_temp_dn_groups, None]
).sum(dim=3)
temp_cls_target = torch.where(
torch.all(torch.logical_not(match), dim=-1),
self.dn_metas["dn_cls_target"].new_tensor(-1),
self.dn_metas["dn_cls_target"],
)
temp_valid_mask = self.dn_metas["valid_mask"]
temp_dn_anchor = self.dn_metas["dn_anchor"]
        # handle the length misalignment between temp_dn and dn caused by a
        # change in num_gt, then concatenate temp_dn with dn
temp_dn_metas = [
temp_dn_feat,
temp_dn_anchor,
temp_reg_target,
temp_cls_target,
temp_valid_mask,
temp_dn_id,
]
dn_metas = [
dn_feat,
dn_anchor,
dn_reg_target,
dn_cls_target,
valid_mask,
dn_id,
]
output = []
for i, (temp_meta, meta) in enumerate(zip(temp_dn_metas, dn_metas)):
if num_temp_dn < num_dn:
pad = (0, num_dn - num_temp_dn)
if temp_meta.dim() == 4:
pad = (0, 0) + pad
else:
assert temp_meta.dim() == 3
temp_meta = F.pad(temp_meta, pad, value=0)
else:
temp_meta = temp_meta[:, :, :num_dn]
mask = temporal_valid_mask[:, None, None]
if meta.dim() == 4:
mask = mask.unsqueeze(dim=-1)
temp_meta = torch.where(
mask, temp_meta, meta[:, :num_temp_dn_groups]
)
meta = torch.cat([temp_meta, meta[:, num_temp_dn_groups:]], dim=1)
meta = meta.flatten(1, 2)
output.append(meta)
output[0] = torch.cat([instance_feature, output[0]], dim=1)
output[1] = torch.cat([anchor, output[1]], dim=1)
return output
def cache_dn(
self,
dn_instance_feature,
dn_anchor,
dn_cls_target,
valid_mask,
dn_id_target,
):
if self.num_temp_dn_groups < 0:
return
num_dn_groups = self.num_dn_groups
bs, num_dn = dn_instance_feature.shape[:2]
num_temp_dn = num_dn // num_dn_groups
temp_group_mask = (
torch.randperm(num_dn_groups) < self.num_temp_dn_groups
)
temp_group_mask = temp_group_mask.to(device=dn_anchor.device)
dn_instance_feature = dn_instance_feature.detach().reshape(
bs, num_dn_groups, num_temp_dn, -1
)[:, temp_group_mask]
dn_anchor = dn_anchor.detach().reshape(
bs, num_dn_groups, num_temp_dn, -1
)[:, temp_group_mask]
dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_temp_dn)[
:, temp_group_mask
]
valid_mask = valid_mask.reshape(bs, num_dn_groups, num_temp_dn)[
:, temp_group_mask
]
if dn_id_target is not None:
dn_id_target = dn_id_target.reshape(
bs, num_dn_groups, num_temp_dn
)[:, temp_group_mask]
self.dn_metas = dict(
dn_instance_feature=dn_instance_feature,
dn_anchor=dn_anchor,
dn_cls_target=dn_cls_target,
valid_mask=valid_mask,
dn_id_target=dn_id_target,
)
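# ---------------------------------------------------------------------------
# New file: a GridMask data-augmentation module (no import in this dump names
# it, so the filename is an assumption; grid_mask.py would be conventional).
# ---------------------------------------------------------------------------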
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
class Grid(object):
def __init__(
self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
):
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.st_prob = prob
self.prob = prob
def set_prob(self, epoch, max_epoch):
self.prob = self.st_prob * epoch / max_epoch
def __call__(self, img, label):
if np.random.rand() > self.prob:
return img, label
h = img.size(1)
w = img.size(2)
self.d1 = 2
self.d2 = min(h, w)
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(self.d1, self.d2)
if self.ratio == 1:
self.l = np.random.randint(1, d)
else:
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[
(hh - h) // 2 : (hh - h) // 2 + h,
(ww - w) // 2 : (ww - w) // 2 + w,
]
mask = torch.from_numpy(mask).float()
if self.mode == 1:
mask = 1 - mask
mask = mask.expand_as(img)
if self.offset:
offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
offset = (1 - mask) * offset
img = img * mask + offset
else:
img = img * mask
return img, label
class GridMask(nn.Module):
def __init__(
self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
):
super(GridMask, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.st_prob = prob
self.prob = prob
def set_prob(self, epoch, max_epoch):
self.prob = self.st_prob * epoch / max_epoch # + 1.#0.5
def forward(self, x):
if np.random.rand() > self.prob or not self.training:
return x
n, c, h, w = x.size()
x = x.view(-1, h, w)
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(2, h)
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[
(hh - h) // 2 : (hh - h) // 2 + h,
(ww - w) // 2 : (ww - w) // 2 + w,
]
mask = torch.from_numpy(mask.copy()).float().cuda()
if self.mode == 1:
mask = 1 - mask
mask = mask.expand_as(x)
if self.offset:
offset = (
torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
.float()
.cuda()
)
x = x * mask + offset * (1 - mask)
else:
x = x * mask
return x.view(n, c, h, w)
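# ---------------------------------------------------------------------------
# New file: likely instance_bank.py, judging from
# `__all__ = ["InstanceBank"]` below.
# ---------------------------------------------------------------------------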
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
__all__ = ["InstanceBank"]
def topk(confidence, k, *inputs):
bs, N = confidence.shape[:2]
confidence, indices = torch.topk(confidence, k, dim=1)
indices = (
indices + torch.arange(bs, device=indices.device)[:, None] * N
).reshape(-1)
outputs = []
    for inp in inputs:
        outputs.append(inp.flatten(end_dim=1)[indices].reshape(bs, k, -1))
return confidence, outputs
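# Usage sketch (assumed shapes, not from this commit): with confidence of
# shape (bs, N) and extra tensors of shape (bs, N, C),
#   conf_topk, (feat_topk, anchor_topk) = topk(confidence, k, feat, anchor)
# returns the k best confidences per sample plus the matching rows of each
# input, reshaped to (bs, k, -1).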
@PLUGIN_LAYERS.register_module()
class InstanceBank(nn.Module):
def __init__(
self,
num_anchor,
embed_dims,
anchor,
anchor_handler=None,
num_temp_instances=0,
default_time_interval=0.5,
confidence_decay=0.6,
anchor_grad=True,
feat_grad=True,
max_time_interval=2,
):
super(InstanceBank, self).__init__()
self.embed_dims = embed_dims
self.num_temp_instances = num_temp_instances
self.default_time_interval = default_time_interval
self.confidence_decay = confidence_decay
self.max_time_interval = max_time_interval
if anchor_handler is not None:
anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS)
assert hasattr(anchor_handler, "anchor_projection")
self.anchor_handler = anchor_handler
if isinstance(anchor, str):
anchor = np.load(anchor)
elif isinstance(anchor, (list, tuple)):
anchor = np.array(anchor)
if len(anchor.shape) == 3: # for map
anchor = anchor.reshape(anchor.shape[0], -1)
self.num_anchor = min(len(anchor), num_anchor)
anchor = anchor[:num_anchor]
self.anchor = nn.Parameter(
torch.tensor(anchor, dtype=torch.float32),
requires_grad=anchor_grad,
)
self.anchor_init = anchor
self.instance_feature = nn.Parameter(
torch.zeros([self.anchor.shape[0], self.embed_dims]),
requires_grad=feat_grad,
)
self.reset()
def init_weight(self):
self.anchor.data = self.anchor.data.new_tensor(self.anchor_init)
if self.instance_feature.requires_grad:
torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)
def reset(self):
self.cached_feature = None
self.cached_anchor = None
self.metas = None
self.mask = None
self.confidence = None
self.temp_confidence = None
self.instance_id = None
self.prev_id = 0
def get(self, batch_size, metas=None, dn_metas=None):
instance_feature = torch.tile(
self.instance_feature[None], (batch_size, 1, 1)
)
anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))
if (
self.cached_anchor is not None
and batch_size == self.cached_anchor.shape[0]
):
history_time = self.metas["timestamp"]
time_interval = metas["timestamp"] - history_time
time_interval = time_interval.to(dtype=instance_feature.dtype)
self.mask = torch.abs(time_interval) <= self.max_time_interval
if self.anchor_handler is not None:
T_temp2cur = self.cached_anchor.new_tensor(
np.stack(
[
x["T_global_inv"]
@ self.metas["img_metas"][i]["T_global"]
for i, x in enumerate(metas["img_metas"])
]
)
)
self.cached_anchor = self.anchor_handler.anchor_projection(
self.cached_anchor,
[T_temp2cur],
time_intervals=[-time_interval],
)[0]
if (
self.anchor_handler is not None
and dn_metas is not None
and batch_size == dn_metas["dn_anchor"].shape[0]
):
num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3]
dn_anchor = self.anchor_handler.anchor_projection(
dn_metas["dn_anchor"].flatten(1, 2),
[T_temp2cur],
time_intervals=[-time_interval],
)[0]
dn_metas["dn_anchor"] = dn_anchor.reshape(
batch_size, num_dn_group, num_dn, -1
)
time_interval = torch.where(
torch.logical_and(time_interval != 0, self.mask),
time_interval,
time_interval.new_tensor(self.default_time_interval),
)
else:
self.reset()
time_interval = instance_feature.new_tensor(
[self.default_time_interval] * batch_size
)
return (
instance_feature,
anchor,
self.cached_feature,
self.cached_anchor,
time_interval,
)
def update(self, instance_feature, anchor, confidence):
if self.cached_feature is None:
return instance_feature, anchor
num_dn = 0
if instance_feature.shape[1] > self.num_anchor:
num_dn = instance_feature.shape[1] - self.num_anchor
dn_instance_feature = instance_feature[:, -num_dn:]
dn_anchor = anchor[:, -num_dn:]
instance_feature = instance_feature[:, : self.num_anchor]
anchor = anchor[:, : self.num_anchor]
confidence = confidence[:, : self.num_anchor]
N = self.num_anchor - self.num_temp_instances
confidence = confidence.max(dim=-1).values
_, (selected_feature, selected_anchor) = topk(
confidence, N, instance_feature, anchor
)
selected_feature = torch.cat(
[self.cached_feature, selected_feature], dim=1
)
selected_anchor = torch.cat(
[self.cached_anchor, selected_anchor], dim=1
)
instance_feature = torch.where(
self.mask[:, None, None], selected_feature, instance_feature
)
anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
self.confidence = torch.where(
self.mask[:, None],
self.confidence,
self.confidence.new_tensor(0)
)
if self.instance_id is not None:
self.instance_id = torch.where(
self.mask[:, None],
self.instance_id,
self.instance_id.new_tensor(-1),
)
if num_dn > 0:
instance_feature = torch.cat(
[instance_feature, dn_instance_feature], dim=1
)
anchor = torch.cat([anchor, dn_anchor], dim=1)
return instance_feature, anchor
def cache(
self,
instance_feature,
anchor,
confidence,
metas=None,
feature_maps=None,
):
if self.num_temp_instances <= 0:
return
instance_feature = instance_feature.detach()
anchor = anchor.detach()
confidence = confidence.detach()
self.metas = metas
confidence = confidence.max(dim=-1).values.sigmoid()
if self.confidence is not None:
confidence[:, : self.num_temp_instances] = torch.maximum(
self.confidence * self.confidence_decay,
confidence[:, : self.num_temp_instances],
)
self.temp_confidence = confidence
(
self.confidence,
(self.cached_feature, self.cached_anchor),
) = topk(confidence, self.num_temp_instances, instance_feature, anchor)
def get_instance_id(self, confidence, anchor=None, threshold=None):
confidence = confidence.max(dim=-1).values.sigmoid()
instance_id = confidence.new_full(confidence.shape, -1).long()
if (
self.instance_id is not None
and self.instance_id.shape[0] == instance_id.shape[0]
):
instance_id[:, : self.instance_id.shape[1]] = self.instance_id
mask = instance_id < 0
if threshold is not None:
mask = mask & (confidence >= threshold)
num_new_instance = mask.sum()
new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id
instance_id[torch.where(mask)] = new_ids
self.prev_id += num_new_instance
self.update_instance_id(instance_id, confidence)
return instance_id
def update_instance_id(self, instance_id=None, confidence=None):
if self.temp_confidence is None:
if confidence.dim() == 3: # bs, num_anchor, num_cls
temp_conf = confidence.max(dim=-1).values
else: # bs, num_anchor
temp_conf = confidence
else:
temp_conf = self.temp_confidence
        instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][0]
instance_id = instance_id.squeeze(dim=-1)
self.instance_id = F.pad(
instance_id,
(0, self.num_anchor - self.num_temp_instances),
value=-1,
)
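# A minimal reference sketch of the `topk` helper used throughout this class
# (its real definition lives elsewhere in this repo; the semantics below are
# assumed from the call sites): take the k highest confidences per sample and
# gather the matching rows of every extra tensor, each returned as (bs, k, -1),
# which is why `update_instance_id` squeezes the trailing dim afterwards.
def _topk_reference(confidence, k, *inputs):
    bs, num_anchor = confidence.shape
    confidence, indices = torch.topk(confidence, k, dim=1)
    flat_indices = (
        indices + torch.arange(bs, device=indices.device)[:, None] * num_anchor
    ).reshape(-1)
    outputs = [
        x.flatten(end_dim=1)[flat_indices].reshape(bs, k, -1) for x in inputs
    ]
    return confidence, outputs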
from .decoder import SparsePoint3DDecoder
from .target import SparsePoint3DTarget, HungarianLinesAssigner
from .match_cost import LinesL1Cost, MapQueriesCost
from .loss import LinesL1Loss, SparseLineLoss
from .map_blocks import (
SparsePoint3DRefinementModule,
SparsePoint3DKeyPointsGenerator,
SparsePoint3DEncoder,
)
from typing import Optional, List
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
@BBOX_CODERS.register_module()
class SparsePoint3DDecoder(object):
def __init__(
self,
coords_dim: int = 2,
score_threshold: Optional[float] = None,
):
super(SparsePoint3DDecoder, self).__init__()
self.score_threshold = score_threshold
self.coords_dim = coords_dim
def decode(
self,
cls_scores,
pts_preds,
instance_id=None,
quality=None,
output_idx=-1,
):
bs, num_pred, num_cls = cls_scores[-1].shape
cls_scores = cls_scores[-1].sigmoid()
pts_preds = pts_preds[-1].reshape(bs, num_pred, -1, self.coords_dim)
cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
num_pred, dim=1
)
cls_ids = indices % num_cls
if self.score_threshold is not None:
mask = cls_scores >= self.score_threshold
output = []
for i in range(bs):
category_ids = cls_ids[i]
scores = cls_scores[i]
pts = pts_preds[i, indices[i] // num_cls]
if self.score_threshold is not None:
category_ids = category_ids[mask[i]]
scores = scores[mask[i]]
pts = pts[mask[i]]
output.append(
{
"vectors": [vec.detach().cpu().numpy() for vec in pts],
"scores": scores.detach().cpu().numpy(),
"labels": category_ids.detach().cpu().numpy(),
}
)
return output
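# Usage sketch (shapes assumed for illustration: bs=1, 50 line queries,
# 3 map classes, 20 points per line with coords_dim=2). Only the last
# decoder layer's outputs are decoded.
def _demo_decode():
    decoder = SparsePoint3DDecoder(coords_dim=2, score_threshold=None)
    cls_scores = [torch.randn(1, 50, 3)]
    pts_preds = [torch.randn(1, 50, 40)]
    results = decoder.decode(cls_scores, pts_preds)
    return results[0]["vectors"][0].shape  # (20, 2) polyline for sample 0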
import torch
import torch.nn as nn
from mmcv.utils import build_from_cfg
from mmdet.models.builder import LOSSES
from mmdet.models.losses import l1_loss, smooth_l1_loss
@LOSSES.register_module()
class LinesL1Loss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
"""
L1 loss. The same as the smooth L1 loss
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.beta = beta
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
shape: [bs, ...]
target (torch.Tensor): The learning target of the prediction.
shape: [bs, ...]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction; useful when not all predictions are valid.
                Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.beta > 0:
loss = smooth_l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta)
else:
loss = l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
num_points = pred.shape[-1] // 2
loss = loss / num_points
        return loss * self.loss_weight
@LOSSES.register_module()
class SparseLineLoss(nn.Module):
def __init__(
self,
loss_line,
num_sample=20,
roi_size=(30, 60),
):
super().__init__()
def build(cfg, registry):
if cfg is None:
return None
return build_from_cfg(cfg, registry)
self.loss_line = build(loss_line, LOSSES)
self.num_sample = num_sample
self.roi_size = roi_size
def forward(
self,
line,
line_target,
weight=None,
avg_factor=None,
prefix="",
suffix="",
**kwargs,
):
output = {}
line = self.normalize_line(line)
line_target = self.normalize_line(line_target)
line_loss = self.loss_line(
line, line_target, weight=weight, avg_factor=avg_factor
)
output[f"{prefix}loss_line{suffix}"] = line_loss
return output
def normalize_line(self, line):
if line.shape[0] == 0:
return line
line = line.view(line.shape[:-1] + (self.num_sample, -1))
origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2])
line = line - origin
        # shift into the ROI frame and normalize to (0, 1)
eps = 1e-5
norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps
line = line / norm
line = line.flatten(-2, -1)
return line
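# Usage sketch (config values assumed, not taken from this repo's configs):
# SparseLineLoss first normalizes both lines into the ROI frame, then applies
# the inner LinesL1Loss registered above in this file.
def _demo_sparse_line_loss():
    loss = SparseLineLoss(
        loss_line=dict(type="LinesL1Loss", loss_weight=10.0, beta=0.01),
        num_sample=20,
        roi_size=(30, 60),
    )
    line = torch.rand(4, 40) * 20 - 10         # 4 lines, 20 pts * 2 ROI coords
    line_target = torch.rand(4, 40) * 20 - 10
    return loss(line, line_target, avg_factor=4)   # {"loss_line": ...}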
from typing import Optional, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import Linear, Scale, bias_init_with_prob
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.registry import (
PLUGIN_LAYERS,
POSITIONAL_ENCODING,
)
from ..blocks import linear_relu_ln
@POSITIONAL_ENCODING.register_module()
class SparsePoint3DEncoder(BaseModule):
def __init__(
self,
embed_dims: int = 256,
num_sample: int = 20,
coords_dim: int = 2,
):
super(SparsePoint3DEncoder, self).__init__()
self.embed_dims = embed_dims
self.input_dims = num_sample * coords_dim
def embedding_layer(input_dims):
return nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, input_dims))
self.pos_fc = embedding_layer(self.input_dims)
def forward(self, anchor: torch.Tensor):
pos_feat = self.pos_fc(anchor)
return pos_feat
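# Usage sketch (shapes assumed): embed 10 polyline anchors of 20 x/y points
# each into 256-d positional features.
def _demo_point_encoder():
    encoder = SparsePoint3DEncoder(embed_dims=256, num_sample=20, coords_dim=2)
    anchor = torch.rand(1, 10, 40)
    return encoder(anchor).shape   # torch.Size([1, 10, 256])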
@PLUGIN_LAYERS.register_module()
class SparsePoint3DRefinementModule(BaseModule):
def __init__(
self,
embed_dims: int = 256,
num_sample: int = 20,
coords_dim: int = 2,
num_cls: int = 3,
with_cls_branch: bool = True,
):
super(SparsePoint3DRefinementModule, self).__init__()
self.embed_dims = embed_dims
self.num_sample = num_sample
self.output_dim = num_sample * coords_dim
self.num_cls = num_cls
self.layers = nn.Sequential(
*linear_relu_ln(embed_dims, 2, 2),
Linear(self.embed_dims, self.output_dim),
Scale([1.0] * self.output_dim),
)
self.with_cls_branch = with_cls_branch
if with_cls_branch:
self.cls_layers = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(self.embed_dims, self.num_cls),
)
def init_weight(self):
if self.with_cls_branch:
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.cls_layers[-1].bias, bias_init)
def forward(
self,
instance_feature: torch.Tensor,
anchor: torch.Tensor,
anchor_embed: torch.Tensor,
time_interval: torch.Tensor = 1.0,
return_cls=True,
):
output = self.layers(instance_feature + anchor_embed)
output = output + anchor
if return_cls:
            assert self.with_cls_branch, "Classification branch is disabled (with_cls_branch=False)."
cls = self.cls_layers(instance_feature) ## NOTE anchor embed?
else:
cls = None
qt = None
return output, cls, qt
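# Usage sketch (shapes assumed): refine 10 polyline anchors with their
# instance features; `qt` stays None since this module has no quality branch.
def _demo_refinement():
    module = SparsePoint3DRefinementModule(
        embed_dims=256, num_sample=20, coords_dim=2, num_cls=3
    )
    instance_feature = torch.randn(1, 10, 256)
    anchor = torch.rand(1, 10, 40)
    anchor_embed = torch.randn(1, 10, 256)
    output, cls, qt = module(instance_feature, anchor, anchor_embed)
    return output.shape, cls.shape   # (1, 10, 40), (1, 10, 3)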
@PLUGIN_LAYERS.register_module()
class SparsePoint3DKeyPointsGenerator(BaseModule):
def __init__(
self,
embed_dims: int = 256,
num_sample: int = 20,
num_learnable_pts: int = 0,
fix_height: Tuple = (0,),
ground_height: int = 0,
):
super(SparsePoint3DKeyPointsGenerator, self).__init__()
self.embed_dims = embed_dims
self.num_sample = num_sample
self.num_learnable_pts = num_learnable_pts
self.num_pts = num_sample * len(fix_height) * num_learnable_pts
if self.num_learnable_pts > 0:
self.learnable_fc = Linear(self.embed_dims, self.num_pts * 2)
self.fix_height = np.array(fix_height)
self.ground_height = ground_height
def init_weight(self):
if self.num_learnable_pts > 0:
xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)
def forward(
self,
anchor,
instance_feature=None,
T_cur2temp_list=None,
cur_timestamp=None,
temp_timestamps=None,
):
assert self.num_learnable_pts > 0, 'No learnable pts'
bs, num_anchor, _ = anchor.shape
key_points = anchor.view(bs, num_anchor, self.num_sample, -1)
offset = (
self.learnable_fc(instance_feature)
.reshape(bs, num_anchor, self.num_sample, len(self.fix_height), self.num_learnable_pts, 2)
)
key_points = offset + key_points[..., None, None, :]
key_points = torch.cat(
[
key_points,
key_points.new_full(key_points.shape[:-1]+(1,), fill_value=self.ground_height),
],
dim=-1,
)
fix_height = key_points.new_tensor(self.fix_height)
height_offset = key_points.new_zeros([len(fix_height), 2])
height_offset = torch.cat([height_offset, fix_height[:,None]], dim=-1)
key_points = key_points + height_offset[None, None, None, :, None]
key_points = key_points.flatten(2, 4)
if (
cur_timestamp is None
or temp_timestamps is None
or T_cur2temp_list is None
or len(temp_timestamps) == 0
):
return key_points
temp_key_points_list = []
for i, t_time in enumerate(temp_timestamps):
temp_key_points = key_points
T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
temp_key_points = (
T_cur2temp[:, None, None, :3]
@ torch.cat(
[
temp_key_points,
torch.ones_like(temp_key_points[..., :1]),
],
dim=-1,
).unsqueeze(-1)
)
temp_key_points = temp_key_points.squeeze(-1)
temp_key_points_list.append(temp_key_points)
return key_points, temp_key_points_list
    def anchor_projection(
self,
anchor,
T_src2dst_list,
src_timestamp=None,
dst_timestamps=None,
time_intervals=None,
):
dst_anchors = []
for i in range(len(T_src2dst_list)):
dst_anchor = anchor.clone()
bs, num_anchor, _ = anchor.shape
dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(1, 2)
T_src2dst = torch.unsqueeze(
T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
)
dst_anchor = (
torch.matmul(
T_src2dst[..., :2, :2], dst_anchor[..., None]
).squeeze(dim=-1)
+ T_src2dst[..., :2, 3]
)
dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(2, 3)
dst_anchors.append(dst_anchor)
return dst_anchors
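# Usage sketch (transform values assumed): project polyline anchors into a
# destination frame one meter ahead along x. Only the top-left 2x2 rotation
# and the x/y translation of the homogeneous 4x4 transform are used.
def _demo_anchor_projection():
    gen = SparsePoint3DKeyPointsGenerator(num_sample=20)
    T_src2dst = torch.eye(4)[None]   # (bs=1, 4, 4)
    T_src2dst[:, 0, 3] = 1.0
    anchor = torch.rand(1, 2, 40)    # (bs, num_anchor, num_sample * 2)
    return gen.anchor_projection(anchor, [T_src2dst])[0].shape  # (1, 2, 40)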
import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST
from mmdet.core.bbox.match_costs import build_match_cost
from torch.nn.functional import smooth_l1_loss
@MATCH_COST.register_module()
class LinesL1Cost(object):
"""LinesL1Cost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.0, beta=0.0, permute=False):
self.weight = weight
self.permute = permute
self.beta = beta
def __call__(self, lines_pred, gt_lines, **kwargs):
"""
Args:
lines_pred (Tensor): predicted normalized lines:
[num_query, 2*num_points]
gt_lines (Tensor): Ground truth lines
[num_gt, 2*num_points] or [num_gt, num_permute, 2*num_points]
Returns:
torch.Tensor: reg_cost value with weight
shape [num_pred, num_gt]
"""
if self.permute:
assert len(gt_lines.shape) == 3
else:
assert len(gt_lines.shape) == 2
num_pred, num_gt = len(lines_pred), len(gt_lines)
        if self.permute:
            # permutation-invariant labels
            gt_lines = gt_lines.flatten(0, 1)  # (num_gt*num_permute, 2*num_pts)
        num_pts = lines_pred.shape[-1] // 2
if self.beta > 0:
lines_pred = lines_pred.unsqueeze(1).repeat(1, len(gt_lines), 1)
gt_lines = gt_lines.unsqueeze(0).repeat(num_pred, 1, 1)
dist_mat = smooth_l1_loss(lines_pred, gt_lines, reduction='none', beta=self.beta).sum(-1)
else:
dist_mat = torch.cdist(lines_pred, gt_lines, p=1)
dist_mat = dist_mat / num_pts
if self.permute:
# dist_mat: (num_pred, num_gt*num_permute)
dist_mat = dist_mat.view(num_pred, num_gt, -1) # (num_pred, num_gt, num_permute)
dist_mat, gt_permute_index = torch.min(dist_mat, 2)
return dist_mat * self.weight, gt_permute_index
return dist_mat * self.weight
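# Usage sketch of the permutation-invariant cost: each GT line is given in
# `num_permute` orderings (e.g. forward and reversed point order) and the
# cheapest ordering is kept per (pred, gt) pair. Shapes assumed.
def _demo_lines_l1_cost():
    cost_fn = LinesL1Cost(weight=1.0, beta=0.0, permute=True)
    lines_pred = torch.rand(5, 40)    # 5 predictions, 20 pts * 2 coords
    gt_lines = torch.rand(3, 2, 40)   # 3 GTs, 2 permutations each
    dist, gt_permute_index = cost_fn(lines_pred, gt_lines)
    return dist.shape, gt_permute_index.shape   # (5, 3), (5, 3)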
@MATCH_COST.register_module()
class MapQueriesCost(object):
def __init__(self, cls_cost, reg_cost, iou_cost=None):
self.cls_cost = build_match_cost(cls_cost)
self.reg_cost = build_match_cost(reg_cost)
self.iou_cost = None
if iou_cost is not None:
self.iou_cost = build_match_cost(iou_cost)
def __call__(self, preds: dict, gts: dict, ignore_cls_cost: bool):
# classification and bboxcost.
cls_cost = self.cls_cost(preds['scores'], gts['labels'])
# regression cost
regkwargs = {}
if 'masks' in preds and 'masks' in gts:
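            # NOTE: DynamicLinesCost (a mask-aware reg cost) is not defined in
            # this file and must be importable when masks are provided.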
            assert isinstance(self.reg_cost, DynamicLinesCost), \
                'mask-based matching requires a DynamicLinesCost reg_cost'
regkwargs = {
'masks_pred': preds['masks'],
'masks_gt': gts['masks'],
}
reg_cost = self.reg_cost(preds['lines'], gts['lines'], **regkwargs)
if self.reg_cost.permute:
reg_cost, gt_permute_idx = reg_cost
# weighted sum of above three costs
if ignore_cls_cost:
cost = reg_cost
else:
cost = cls_cost + reg_cost
# Iou
if self.iou_cost is not None:
            iou_cost = self.iou_cost(preds['lines'], gts['lines'])
cost += iou_cost
if self.reg_cost.permute:
return cost, gt_permute_idx
return cost
import torch
import numpy as np
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from mmdet.core.bbox.builder import (BBOX_SAMPLERS, BBOX_ASSIGNERS)
from mmdet.core.bbox.match_costs import build_match_cost
from mmdet.core import (build_assigner, build_sampler)
from mmdet.core.bbox.assigners import (AssignResult, BaseAssigner)
from ..base_target import BaseTargetWithDenoising
@BBOX_SAMPLERS.register_module()
class SparsePoint3DTarget(BaseTargetWithDenoising):
def __init__(
self,
assigner=None,
num_dn_groups=0,
dn_noise_scale=0.5,
max_dn_gt=32,
add_neg_dn=True,
num_temp_dn_groups=0,
num_cls=3,
num_sample=20,
roi_size=(30, 60),
):
super(SparsePoint3DTarget, self).__init__(
num_dn_groups, num_temp_dn_groups
)
self.assigner = build_assigner(assigner)
self.dn_noise_scale = dn_noise_scale
self.max_dn_gt = max_dn_gt
self.add_neg_dn = add_neg_dn
self.num_cls = num_cls
self.num_sample = num_sample
self.roi_size = roi_size
def sample(
self,
cls_preds,
pts_preds,
cls_targets,
pts_targets,
):
pts_targets = [x.flatten(2, 3) if len(x.shape)==4 else x for x in pts_targets]
indices = []
        for (cls_pred, pts_pred, cls_target, pts_target) in zip(
            cls_preds, pts_preds, cls_targets, pts_targets
        ):
            # normalize to (0, 1)
            pts_pred = self.normalize_line(pts_pred)
            pts_target = self.normalize_line(pts_target)
            preds = dict(lines=pts_pred, scores=cls_pred)
            gts = dict(lines=pts_target, labels=cls_target)
indice = self.assigner.assign(preds, gts)
indices.append(indice)
bs, num_pred, num_cls = cls_preds.shape
output_cls_target = cls_targets[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls
output_box_target = pts_preds.new_zeros(pts_preds.shape)
output_reg_weights = pts_preds.new_zeros(pts_preds.shape)
for i, (pred_idx, target_idx, gt_permute_index) in enumerate(indices):
if len(cls_targets[i]) == 0:
continue
permute_idx = gt_permute_index[pred_idx, target_idx]
output_cls_target[i, pred_idx] = cls_targets[i][target_idx]
output_box_target[i, pred_idx] = pts_targets[i][target_idx, permute_idx]
output_reg_weights[i, pred_idx] = 1
return output_cls_target, output_box_target, output_reg_weights
def normalize_line(self, line):
if line.shape[0] == 0:
return line
line = line.view(line.shape[:-1] + (self.num_sample, -1))
origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2])
line = line - origin
        # shift into the ROI frame and normalize to (0, 1)
eps = 1e-5
norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps
line = line / norm
line = line.flatten(-2, -1)
return line
@BBOX_ASSIGNERS.register_module()
class HungarianLinesAssigner(BaseAssigner):
"""
Computes one-to-one matching between predictions and ground truth.
This class computes an assignment between the targets and the predictions
    based on the costs. The cost is a weighted sum of two components:
    the classification cost and the regression L1 cost. The
targets don't include the no_object, so generally there are more
predictions than targets. After the one-to-one matching, the un-matched
are treated as backgrounds. Thus each query prediction will be assigned
with `0` or a positive integer indicating the ground truth index:
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
    Args:
        cost (dict): Config dict of the weighted match cost
            (e.g. MapQueriesCost).
"""
    def __init__(self, cost=None, **kwargs):
        assert cost is not None, "a match cost config is required"
        self.cost = build_match_cost(cost)
def assign(self,
preds: dict,
gts: dict,
ignore_cls_cost=False,
gt_bboxes_ignore=None,
eps=1e-7):
"""
Computes one-to-one matching based on the weighted costs.
This method assign each query prediction to a ground truth or
background. The `assigned_gt_inds` with -1 means don't care,
0 means negative sample, and positive number is the index (1-based)
of assigned gt.
The assignment is done in the following steps, the order matters.
1. assign every prediction to -1
2. compute the weighted costs
3. do Hungarian matching on CPU based on the costs
4. assign all to 0 (background) first, then for each matched pair
between predictions and gts, treat this prediction as foreground
and assign the corresponding gt index (plus 1) to it.
Args:
lines_pred (Tensor): predicted normalized lines:
[num_query, num_points, 2]
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
lines_gt (Tensor): Ground truth lines
[num_gt, num_points, 2].
labels_gt (Tensor): Label of `gt_bboxes`, shape (num_gt,).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`. Default None.
eps (int | float, optional): A value added to the denominator for
numerical stability. Default 1e-7.
Returns:
:obj:`AssignResult`: The assigned result.
"""
assert gt_bboxes_ignore is None, \
'Only case when gt_bboxes_ignore is None is supported.'
num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0)
if num_gts == 0 or num_lines == 0:
return None, None, None
# compute the weighted costs
gt_permute_idx = None # (num_preds, num_gts)
if self.cost.reg_cost.permute:
cost, gt_permute_idx = self.cost(preds, gts, ignore_cls_cost)
else:
cost = self.cost(preds, gts, ignore_cls_cost)
# do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu().numpy()
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
return matched_row_inds, matched_col_inds, gt_permute_idx
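# Usage sketch (cost config assumed; FocalLossCost comes from mmdet's
# match-cost registry, while LinesL1Cost and MapQueriesCost are registered
# by this plugin's match_cost module):
def _demo_assign():
    assigner = HungarianLinesAssigner(
        cost=dict(
            type="MapQueriesCost",
            cls_cost=dict(type="FocalLossCost", weight=1.0),
            reg_cost=dict(type="LinesL1Cost", weight=10.0, permute=True),
        )
    )
    preds = dict(scores=torch.randn(50, 3), lines=torch.rand(50, 40))
    gts = dict(labels=torch.tensor([0, 2]), lines=torch.rand(2, 2, 40))
    return assigner.assign(preds, gts)   # (pred_idx, gt_idx, permute_idx)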
from .motion_planning_head import MotionPlanningHead
from .motion_blocks import MotionPlanningRefinementModule
from .instance_queue import InstanceQueue
from .target import MotionTarget, PlanningTarget
from .decoder import SparseBox3DMotionDecoder, HierarchicalPlanningDecoder
from typing import Optional
import numpy as np
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
from projects.mmdet3d_plugin.core.box3d import *
from projects.mmdet3d_plugin.models.detection3d.decoder import *
from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
@BBOX_CODERS.register_module()
class SparseBox3DMotionDecoder(SparseBox3DDecoder):
def __init__(self):
super(SparseBox3DMotionDecoder, self).__init__()
def decode(
self,
cls_scores,
box_preds,
instance_id=None,
quality=None,
motion_output=None,
output_idx=-1,
):
squeeze_cls = instance_id is not None
cls_scores = cls_scores[output_idx].sigmoid()
if squeeze_cls:
cls_scores, cls_ids = cls_scores.max(dim=-1)
cls_scores = cls_scores.unsqueeze(dim=-1)
box_preds = box_preds[output_idx]
bs, num_pred, num_cls = cls_scores.shape
cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
self.num_output, dim=1, sorted=self.sorted
)
if not squeeze_cls:
cls_ids = indices % num_cls
if self.score_threshold is not None:
mask = cls_scores >= self.score_threshold
        if quality is None or quality[output_idx] is None:
            quality = None
if quality is not None:
centerness = quality[output_idx][..., CNS]
centerness = torch.gather(centerness, 1, indices // num_cls)
cls_scores_origin = cls_scores.clone()
cls_scores *= centerness.sigmoid()
cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
if not squeeze_cls:
cls_ids = torch.gather(cls_ids, 1, idx)
if self.score_threshold is not None:
mask = torch.gather(mask, 1, idx)
indices = torch.gather(indices, 1, idx)
output = []
anchor_queue = motion_output["anchor_queue"]
anchor_queue = torch.stack(anchor_queue, dim=2)
period = motion_output["period"]
for i in range(bs):
category_ids = cls_ids[i]
if squeeze_cls:
category_ids = category_ids[indices[i]]
scores = cls_scores[i]
box = box_preds[i, indices[i] // num_cls]
if self.score_threshold is not None:
category_ids = category_ids[mask[i]]
scores = scores[mask[i]]
box = box[mask[i]]
if quality is not None:
scores_origin = cls_scores_origin[i]
if self.score_threshold is not None:
scores_origin = scores_origin[mask[i]]
box = decode_box(box)
trajs = motion_output["prediction"][-1]
traj_cls = motion_output["classification"][-1].sigmoid()
traj = trajs[i, indices[i] // num_cls]
traj_cls = traj_cls[i, indices[i] // num_cls]
if self.score_threshold is not None:
traj = traj[mask[i]]
traj_cls = traj_cls[mask[i]]
traj = traj.cumsum(dim=-2) + box[:, None, None, :2]
output.append(
{
"trajs_3d": traj.cpu(),
"trajs_score": traj_cls.cpu()
}
)
temp_anchor = anchor_queue[i, indices[i] // num_cls]
temp_period = period[i, indices[i] // num_cls]
if self.score_threshold is not None:
temp_anchor = temp_anchor[mask[i]]
temp_period = temp_period[mask[i]]
num_pred, queue_len = temp_anchor.shape[:2]
temp_anchor = temp_anchor.flatten(0, 1)
temp_anchor = decode_box(temp_anchor)
temp_anchor = temp_anchor.reshape([num_pred, queue_len, box.shape[-1]])
output[-1]['anchor_queue'] = temp_anchor.cpu()
output[-1]['period'] = temp_period.cpu()
return output
@BBOX_CODERS.register_module()
class HierarchicalPlanningDecoder(object):
def __init__(
self,
ego_fut_ts,
ego_fut_mode,
use_rescore=False,
):
super(HierarchicalPlanningDecoder, self).__init__()
self.ego_fut_ts = ego_fut_ts
self.ego_fut_mode = ego_fut_mode
self.use_rescore = use_rescore
def decode(
self,
det_output,
motion_output,
planning_output,
data,
):
classification = planning_output['classification'][-1]
prediction = planning_output['prediction'][-1]
bs = classification.shape[0]
classification = classification.reshape(bs, 3, self.ego_fut_mode)
prediction = prediction.reshape(bs, 3, self.ego_fut_mode, self.ego_fut_ts, 2).cumsum(dim=-2)
classification, final_planning = self.select(det_output, motion_output, classification, prediction, data)
anchor_queue = planning_output["anchor_queue"]
anchor_queue = torch.stack(anchor_queue, dim=2)
period = planning_output["period"]
output = []
for i, (cls, pred) in enumerate(zip(classification, prediction)):
output.append(
{
"planning_score": cls.sigmoid().cpu(),
"planning": pred.cpu(),
"final_planning": final_planning[i].cpu(),
"ego_period": period[i].cpu(),
"ego_anchor_queue": decode_box(anchor_queue[i]).cpu(),
}
)
return output
def select(
self,
det_output,
motion_output,
plan_cls,
plan_reg,
data,
):
det_classification = det_output["classification"][-1].sigmoid()
det_anchors = det_output["prediction"][-1]
det_confidence = det_classification.max(dim=-1).values
motion_cls = motion_output["classification"][-1].sigmoid()
motion_reg = motion_output["prediction"][-1]
# cmd select
bs = motion_cls.shape[0]
bs_indices = torch.arange(bs, device=motion_cls.device)
cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)
plan_cls_full = plan_cls.detach().clone()
plan_cls = plan_cls[bs_indices, cmd]
plan_reg = plan_reg[bs_indices, cmd]
# rescore
if self.use_rescore:
plan_cls = self.rescore(
plan_cls,
plan_reg,
motion_cls,
motion_reg,
det_anchors,
det_confidence,
)
plan_cls_full[bs_indices, cmd] = plan_cls
mode_idx = plan_cls.argmax(dim=-1)
final_planning = plan_reg[bs_indices, mode_idx]
return plan_cls_full, final_planning
def rescore(
self,
plan_cls,
plan_reg,
motion_cls,
motion_reg,
det_anchors,
det_confidence,
score_thresh=0.5,
static_dis_thresh=0.5,
dim_scale=1.1,
num_motion_mode=1,
offset=0.5,
):
def cat_with_zero(traj):
zeros = traj.new_zeros(traj.shape[:-2] + (1, 2))
traj_cat = torch.cat([zeros, traj], dim=-2)
return traj_cat
def get_yaw(traj, start_yaw=np.pi/2):
yaw = traj.new_zeros(traj.shape[:-1])
yaw[..., 1:-1] = torch.atan2(
traj[..., 2:, 1] - traj[..., :-2, 1],
traj[..., 2:, 0] - traj[..., :-2, 0],
)
yaw[..., -1] = torch.atan2(
traj[..., -1, 1] - traj[..., -2, 1],
traj[..., -1, 0] - traj[..., -2, 0],
)
yaw[..., 0] = start_yaw
# for static object, estimated future yaw would be unstable
start = traj[..., 0, :]
end = traj[..., -1, :]
dist = torch.linalg.norm(end - start, dim=-1)
mask = dist < static_dis_thresh
start_yaw = yaw[..., 0].unsqueeze(-1)
yaw = torch.where(
mask.unsqueeze(-1),
start_yaw,
yaw,
)
return yaw.unsqueeze(-1)
## ego
bs = plan_reg.shape[0]
plan_reg_cat = cat_with_zero(plan_reg)
ego_box = det_anchors.new_zeros(bs, self.ego_fut_mode, self.ego_fut_ts + 1, 7)
ego_box[..., [X, Y]] = plan_reg_cat
ego_box[..., [W, L, H]] = ego_box.new_tensor([4.08, 1.73, 1.56]) * dim_scale
ego_box[..., [YAW]] = get_yaw(plan_reg_cat)
## motion
motion_reg = motion_reg[..., :self.ego_fut_ts, :].cumsum(-2)
motion_reg = cat_with_zero(motion_reg) + det_anchors[:, :, None, None, :2]
_, motion_mode_idx = torch.topk(motion_cls, num_motion_mode, dim=-1)
motion_mode_idx = motion_mode_idx[..., None, None].repeat(1, 1, 1, self.ego_fut_ts + 1, 2)
motion_reg = torch.gather(motion_reg, 2, motion_mode_idx)
motion_box = motion_reg.new_zeros(motion_reg.shape[:-1] + (7,))
motion_box[..., [X, Y]] = motion_reg
motion_box[..., [W, L, H]] = det_anchors[..., None, None, [W, L, H]].exp()
box_yaw = torch.atan2(
det_anchors[..., SIN_YAW],
det_anchors[..., COS_YAW],
)
motion_box[..., [YAW]] = get_yaw(motion_reg, box_yaw.unsqueeze(-1))
filter_mask = det_confidence < score_thresh
motion_box[filter_mask] = 1e6
ego_box = ego_box[..., 1:, :]
motion_box = motion_box[..., 1:, :]
bs, num_ego_mode, ts, _ = ego_box.shape
bs, num_anchor, num_motion_mode, ts, _ = motion_box.shape
ego_box = ego_box[:, None, None].repeat(1, num_anchor, num_motion_mode, 1, 1, 1).flatten(0, -2)
motion_box = motion_box.unsqueeze(3).repeat(1, 1, 1, num_ego_mode, 1, 1).flatten(0, -2)
        ego_box[..., 0] += offset * torch.cos(ego_box[..., 6])
        ego_box[..., 1] += offset * torch.sin(ego_box[..., 6])
col = check_collision(ego_box, motion_box)
col = col.reshape(bs, num_anchor, num_motion_mode, num_ego_mode, ts).permute(0, 3, 1, 2, 4)
col = col.flatten(2, -1).any(dim=-1)
all_col = col.all(dim=-1)
col[all_col] = False # for case that all modes collide, no need to rescore
score_offset = col.float() * -999
plan_cls = plan_cls + score_offset
return plan_cls
def check_collision(boxes1, boxes2):
'''
A rough check for collision detection:
check if any corner point of boxes1 is inside boxes2 and vice versa.
boxes1: tensor with shape [N, 7], [x, y, z, w, l, h, yaw]
boxes2: tensor with shape [N, 7]
'''
col_1 = corners_in_box(boxes1.clone(), boxes2.clone())
col_2 = corners_in_box(boxes2.clone(), boxes1.clone())
collision = torch.logical_or(col_1, col_2)
return collision
def corners_in_box(boxes1, boxes2):
    if boxes1.shape[0] == 0 or boxes2.shape[0] == 0:
        # return a tensor so torch.logical_or in check_collision
        # does not fail on python bools
        return boxes1.new_zeros(boxes1.shape[0], dtype=torch.bool)
boxes1_yaw = boxes1[:, 6].clone()
boxes1_loc = boxes1[:, :3].clone()
cos_yaw = torch.cos(-boxes1_yaw)
sin_yaw = torch.sin(-boxes1_yaw)
rot_mat_T = torch.stack(
[
torch.stack([cos_yaw, sin_yaw]),
torch.stack([-sin_yaw, cos_yaw]),
]
)
# translate and rotate boxes
boxes1[:, :3] = boxes1[:, :3] - boxes1_loc
boxes1[:, :2] = torch.einsum('ij,jki->ik', boxes1[:, :2], rot_mat_T)
boxes1[:, 6] = boxes1[:, 6] - boxes1_yaw
boxes2[:, :3] = boxes2[:, :3] - boxes1_loc
boxes2[:, :2] = torch.einsum('ij,jki->ik', boxes2[:, :2], rot_mat_T)
boxes2[:, 6] = boxes2[:, 6] - boxes1_yaw
corners_box2 = box3d_to_corners(boxes2)[:, [0, 3, 7, 4], :2]
corners_box2 = torch.from_numpy(corners_box2).to(boxes2.device)
    # half-extent checks in boxes1's local frame; avoid reusing the box3d
    # index constants W/L/H as variable names (index 3 is W, index 4 is L)
    extent_x = boxes1[:, [3]]
    extent_y = boxes1[:, [4]]
    collision = torch.logical_and(
        torch.logical_and(
            corners_box2[..., 0] <= extent_x / 2,
            corners_box2[..., 0] >= -extent_x / 2,
        ),
        torch.logical_and(
            corners_box2[..., 1] <= extent_y / 2,
            corners_box2[..., 1] >= -extent_y / 2,
        ),
    )
collision = collision.any(dim=-1)
return collision
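# Sanity-check sketch (box values assumed): two 1 m cubes 10 m apart do not
# collide, while a pair offset by only 0.2 m does.
def _demo_check_collision():
    box_a = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]])
    box_far = torch.tensor([[10.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]])
    box_near = torch.tensor([[0.2, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]])
    return check_collision(box_a, box_far), check_collision(box_a, box_near)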
import copy
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from projects.mmdet3d_plugin.ops import feature_maps_format
from projects.mmdet3d_plugin.core.box3d import *
@PLUGIN_LAYERS.register_module()
class InstanceQueue(nn.Module):
def __init__(
self,
embed_dims,
queue_length=0,
tracking_threshold=0,
feature_map_scale=None,
):
super(InstanceQueue, self).__init__()
self.embed_dims = embed_dims
self.queue_length = queue_length
self.tracking_threshold = tracking_threshold
kernel_size = tuple([int(x / 2) for x in feature_map_scale])
self.ego_feature_encoder = nn.Sequential(
nn.Conv2d(embed_dims, embed_dims, 3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(embed_dims),
nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(embed_dims),
nn.ReLU(),
nn.AvgPool2d(kernel_size),
)
        self.ego_anchor = nn.Parameter(
            torch.tensor(
                [[0, 0.5, -1.84 + 1.56 / 2, np.log(4.08), np.log(1.73),
                  np.log(1.56), 1, 0, 0, 0, 0]],
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
self.reset()
def reset(self):
self.metas = None
self.prev_instance_id = None
self.prev_confidence = None
self.period = None
self.instance_feature_queue = []
self.anchor_queue = []
self.prev_ego_status = None
self.ego_period = None
self.ego_feature_queue = []
self.ego_anchor_queue = []
def get(
self,
det_output,
feature_maps,
metas,
batch_size,
mask,
anchor_handler,
):
if (
self.period is not None
and batch_size == self.period.shape[0]
):
if anchor_handler is not None:
T_temp2cur = feature_maps[0].new_tensor(
np.stack(
[
x["T_global_inv"]
@ self.metas["img_metas"][i]["T_global"]
for i, x in enumerate(metas["img_metas"])
]
)
)
for i in range(len(self.anchor_queue)):
temp_anchor = self.anchor_queue[i]
temp_anchor = anchor_handler.anchor_projection(
temp_anchor,
[T_temp2cur],
)[0]
self.anchor_queue[i] = temp_anchor
for i in range(len(self.ego_anchor_queue)):
temp_anchor = self.ego_anchor_queue[i]
temp_anchor = anchor_handler.anchor_projection(
temp_anchor,
[T_temp2cur],
)[0]
self.ego_anchor_queue[i] = temp_anchor
else:
self.reset()
self.prepare_motion(det_output, mask)
ego_feature, ego_anchor = self.prepare_planning(feature_maps, mask, batch_size)
# temporal
temp_instance_feature = torch.stack(self.instance_feature_queue, dim=2)
temp_anchor = torch.stack(self.anchor_queue, dim=2)
temp_ego_feature = torch.stack(self.ego_feature_queue, dim=2)
temp_ego_anchor = torch.stack(self.ego_anchor_queue, dim=2)
period = torch.cat([self.period, self.ego_period], dim=1)
temp_instance_feature = torch.cat([temp_instance_feature, temp_ego_feature], dim=1)
temp_anchor = torch.cat([temp_anchor, temp_ego_anchor], dim=1)
num_agent = temp_anchor.shape[1]
temp_mask = torch.arange(len(self.anchor_queue), 0, -1, device=temp_anchor.device)
temp_mask = temp_mask[None, None].repeat((batch_size, num_agent, 1))
temp_mask = torch.gt(temp_mask, period[..., None])
return ego_feature, ego_anchor, temp_instance_feature, temp_anchor, temp_mask
def prepare_motion(
self,
det_output,
mask,
):
instance_feature = det_output["instance_feature"]
det_anchors = det_output["prediction"][-1]
        if self.period is None:
self.period = instance_feature.new_zeros(instance_feature.shape[:2]).long()
else:
instance_id = det_output['instance_id']
prev_instance_id = self.prev_instance_id
match = instance_id[..., None] == prev_instance_id[:, None]
if self.tracking_threshold > 0:
temp_mask = self.prev_confidence > self.tracking_threshold
match = match * temp_mask.unsqueeze(1)
for i in range(len(self.instance_feature_queue)):
temp_feature = self.instance_feature_queue[i]
temp_feature = (
match[..., None] * temp_feature[:, None]
).sum(dim=2)
self.instance_feature_queue[i] = temp_feature
temp_anchor = self.anchor_queue[i]
temp_anchor = (
match[..., None] * temp_anchor[:, None]
).sum(dim=2)
self.anchor_queue[i] = temp_anchor
self.period = (
match * self.period[:, None]
).sum(dim=2)
self.instance_feature_queue.append(instance_feature.detach())
self.anchor_queue.append(det_anchors.detach())
self.period += 1
if len(self.instance_feature_queue) > self.queue_length:
self.instance_feature_queue.pop(0)
self.anchor_queue.pop(0)
self.period = torch.clip(self.period, 0, self.queue_length)
def prepare_planning(
self,
feature_maps,
mask,
batch_size,
):
## ego instance init
feature_maps_inv = feature_maps_format(feature_maps, inverse=True)
feature_map = feature_maps_inv[0][-1][:, 0]
ego_feature = self.ego_feature_encoder(feature_map)
ego_feature = ego_feature.unsqueeze(1).squeeze(-1).squeeze(-1)
ego_anchor = torch.tile(
self.ego_anchor[None], (batch_size, 1, 1)
)
if self.prev_ego_status is not None:
prev_ego_status = torch.where(
mask[:, None, None],
self.prev_ego_status,
self.prev_ego_status.new_tensor(0),
)
ego_anchor[..., VY] = prev_ego_status[..., 6]
        if self.ego_period is None:
self.ego_period = ego_feature.new_zeros((batch_size, 1)).long()
else:
self.ego_period = torch.where(
mask[:, None],
self.ego_period,
self.ego_period.new_tensor(0),
)
self.ego_feature_queue.append(ego_feature.detach())
self.ego_anchor_queue.append(ego_anchor.detach())
self.ego_period += 1
if len(self.ego_feature_queue) > self.queue_length:
self.ego_feature_queue.pop(0)
self.ego_anchor_queue.pop(0)
self.ego_period = torch.clip(self.ego_period, 0, self.queue_length)
return ego_feature, ego_anchor
def cache_motion(self, instance_feature, det_output, metas):
det_classification = det_output["classification"][-1].sigmoid()
det_confidence = det_classification.max(dim=-1).values
instance_id = det_output['instance_id']
self.metas = metas
self.prev_confidence = det_confidence.detach()
self.prev_instance_id = instance_id
def cache_planning(self, ego_feature, ego_status):
self.prev_ego_status = ego_status.detach()
self.ego_feature_queue[-1] = ego_feature.detach()
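# Sketch of the validity mask built in `get` (values assumed): a queue slot
# is masked out (True) when it is older than the instance's tracked period.
def _demo_temporal_mask():
    queue_length = 4
    period = torch.tensor([[2, 4]])   # bs=1, two agents tracked 2 and 4 frames
    temp_mask = torch.arange(queue_length, 0, -1)[None, None].repeat(1, 2, 1)
    # agent 0: [True, True, False, False]; agent 1: [False] * 4
    return torch.gt(temp_mask, period[..., None])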
import torch
import torch.nn as nn
import numpy as np
from mmcv.cnn import Linear, Scale, bias_init_with_prob
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.registry import (
PLUGIN_LAYERS,
)
from projects.mmdet3d_plugin.core.box3d import *
from ..blocks import linear_relu_ln
@PLUGIN_LAYERS.register_module()
class MotionPlanningRefinementModule(BaseModule):
def __init__(
self,
embed_dims=256,
fut_ts=12,
fut_mode=6,
ego_fut_ts=6,
ego_fut_mode=3,
):
super(MotionPlanningRefinementModule, self).__init__()
self.embed_dims = embed_dims
self.fut_ts = fut_ts
self.fut_mode = fut_mode
self.ego_fut_ts = ego_fut_ts
self.ego_fut_mode = ego_fut_mode
self.motion_cls_branch = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(embed_dims, 1),
)
self.motion_reg_branch = nn.Sequential(
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, fut_ts * 2),
)
self.plan_cls_branch = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 2),
Linear(embed_dims, 1),
)
self.plan_reg_branch = nn.Sequential(
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, ego_fut_ts * 2),
)
self.plan_status_branch = nn.Sequential(
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, embed_dims),
nn.ReLU(),
nn.Linear(embed_dims, 10),
)
def init_weight(self):
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.motion_cls_branch[-1].bias, bias_init)
nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init)
def forward(
self,
motion_query,
plan_query,
ego_feature,
ego_anchor_embed,
):
bs, num_anchor = motion_query.shape[:2]
motion_cls = self.motion_cls_branch(motion_query).squeeze(-1)
motion_reg = self.motion_reg_branch(motion_query).reshape(bs, num_anchor, self.fut_mode, self.fut_ts, 2)
plan_cls = self.plan_cls_branch(plan_query).squeeze(-1)
plan_reg = self.plan_reg_branch(plan_query).reshape(bs, 1, 3 * self.ego_fut_mode, self.ego_fut_ts, 2)
planning_status = self.plan_status_branch(ego_feature + ego_anchor_embed)
return motion_cls, motion_reg, plan_cls, plan_reg, planning_status
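# Usage sketch (shapes assumed from the reshapes above): ten agents with
# fut_mode=6 motion queries, and 3 driving commands * ego_fut_mode=3
# planning queries for the single ego instance.
def _demo_motion_planning_refinement():
    module = MotionPlanningRefinementModule(
        embed_dims=256, fut_ts=12, fut_mode=6, ego_fut_ts=6, ego_fut_mode=3
    )
    motion_query = torch.randn(1, 10, 6, 256)
    plan_query = torch.randn(1, 1, 9, 256)
    ego_feature = torch.randn(1, 1, 256)
    ego_anchor_embed = torch.randn(1, 1, 256)
    outs = module(motion_query, plan_query, ego_feature, ego_anchor_embed)
    return [o.shape for o in outs]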