Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
TS-MODELS-OPT
training
Autonomous-Driving-models
Commits
19472568
Commit
19472568
authored
Apr 08, 2026
by
雍大凯
Browse files
将子模块转换为普通目录
parent
51e55208
Changes
233
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6025 additions
and
0 deletions
+6025
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/__init__.py
...pTR/projects/mmdet3d_plugin/maptr/dense_heads/__init__.py
+2
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/maptr_head.py
...R/projects/mmdet3d_plugin/maptr/dense_heads/maptr_head.py
+767
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/maptrv2_head.py
...projects/mmdet3d_plugin/maptr/dense_heads/maptrv2_head.py
+1014
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/__init__.py
...MapTR/projects/mmdet3d_plugin/maptr/detectors/__init__.py
+2
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/maptr.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/detectors/maptr.py
+442
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/maptrv2.py
.../MapTR/projects/mmdet3d_plugin/maptr/detectors/maptrv2.py
+411
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/__init__.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/losses/__init__.py
+7
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/map_loss.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/losses/map_loss.py
+719
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/simple_loss.py
...MapTR/projects/mmdet3d_plugin/maptr/losses/simple_loss.py
+115
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/__init__.py
...2/MapTR/projects/mmdet3d_plugin/maptr/modules/__init__.py
+5
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/builder.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/modules/builder.py
+5
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/decoder.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/modules/decoder.py
+266
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/encoder.py
...v2/MapTR/projects/mmdet3d_plugin/maptr/modules/encoder.py
+1479
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/geometry_kernel_attention.py
...mmdet3d_plugin/maptr/modules/geometry_kernel_attention.py
+506
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/__init__.py
...lugin/maptr/modules/ops/geometric_kernel_attn/__init__.py
+1
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/function/__init__.py
...tr/modules/ops/geometric_kernel_attn/function/__init__.py
+1
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/function/geometric_kernel_attn_func.py
...metric_kernel_attn/function/geometric_kernel_attn_func.py
+31
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/setup.py
...d_plugin/maptr/modules/ops/geometric_kernel_attn/setup.py
+66
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/src/geometric_kernel_attn.h
...les/ops/geometric_kernel_attn/src/geometric_kernel_attn.h
+42
-0
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/src/geometric_kernel_attn_cuda.cu
...s/geometric_kernel_attn/src/geometric_kernel_attn_cuda.cu
+144
-0
No files found.
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/__init__.py
0 → 100644
View file @
19472568
from
.maptr_head
import
MapTRHead
from
.maptrv2_head
import
MapTRv2Head
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/maptr_head.py
0 → 100644
View file @
19472568
import
copy
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models
import
HEADS
,
build_loss
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet3d.core.bbox.coders
import
build_bbox_coder
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
xavier_init
,
constant_init
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.core.bbox.transforms
import
bbox_xyxy_to_cxcywh
,
bbox_cxcywh_to_xyxy
from
mmdet.core
import
(
multi_apply
,
multi_apply
,
reduce_mean
)
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
def normalize_2d_bbox(bboxes, pc_range):
    """Convert xyxy boxes to cxcywh and scale them by the x/y extent of ``pc_range``.

    Args:
        bboxes (Tensor): Boxes handed to ``bbox_xyxy_to_cxcywh``.
        pc_range (sequence): Range indexed as
            ``[x_min, y_min, _, x_max, y_max, _]``; only indices 0, 1, 3
            and 4 are read.

    Returns:
        Tensor: cxcywh boxes whose centre has been shifted by
        ``(x_min, y_min)`` and whose four channels are divided by
        ``(x_extent, y_extent, x_extent, y_extent)``.
    """
    x_extent = pc_range[3] - pc_range[0]
    y_extent = pc_range[4] - pc_range[1]

    boxes = bbox_xyxy_to_cxcywh(bboxes)
    # Only the centre is translated; width/height are just rescaled below.
    boxes[..., 0:1] = boxes[..., 0:1] - pc_range[0]
    boxes[..., 1:2] = boxes[..., 1:2] - pc_range[1]

    scale = bboxes.new_tensor([x_extent, y_extent, x_extent, y_extent])
    return boxes / scale
def normalize_2d_pts(pts, pc_range):
    """Shift and scale 2-D points by the x/y extent of ``pc_range``.

    Args:
        pts (Tensor): Points whose last dimension holds ``(x, y)``.
        pc_range (sequence): Range indexed as
            ``[x_min, y_min, _, x_max, y_max, _]``; only indices 0, 1, 3
            and 4 are read.

    Returns:
        Tensor: A new tensor (the input is not modified) equal to
        ``(pts - (x_min, y_min)) / (x_extent, y_extent)``.
    """
    x_extent = pc_range[3] - pc_range[0]
    y_extent = pc_range[4] - pc_range[1]

    shifted = pts.clone()
    shifted[..., 0:1] = pts[..., 0:1] - pc_range[0]
    shifted[..., 1:2] = pts[..., 1:2] - pc_range[1]

    scale = pts.new_tensor([x_extent, y_extent])
    return shifted / scale
def denormalize_2d_bbox(bboxes, pc_range):
    """Convert normalized cxcywh boxes to xyxy boxes in ``pc_range`` coordinates.

    Args:
        bboxes (Tensor): Boxes handed to ``bbox_cxcywh_to_xyxy``.
        pc_range (sequence): Range indexed as
            ``[x_min, y_min, _, x_max, y_max, _]``; only indices 0, 1, 3
            and 4 are read.

    Returns:
        Tensor: xyxy boxes where the even channels (x) are mapped by
        ``x * x_extent + x_min`` and the odd channels (y) by
        ``y * y_extent + y_min``.
    """
    xyxy = bbox_cxcywh_to_xyxy(bboxes)
    # Channels 0, 2 are x coordinates; channels 1, 3 are y coordinates.
    xyxy[..., 0::2] = xyxy[..., 0::2] * (pc_range[3] - pc_range[0]) + pc_range[0]
    xyxy[..., 1::2] = xyxy[..., 1::2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    return xyxy
def denormalize_2d_pts(pts, pc_range):
    """Map normalized 2-D points back into ``pc_range`` coordinates.

    Args:
        pts (Tensor): Points whose last dimension holds ``(x, y)``.
        pc_range (sequence): Range indexed as
            ``[x_min, y_min, _, x_max, y_max, _]``; only indices 0, 1, 3
            and 4 are read.

    Returns:
        Tensor: A new tensor (the input is not modified) with
        ``x * x_extent + x_min`` and ``y * y_extent + y_min``.
    """
    x_extent = pc_range[3] - pc_range[0]
    y_extent = pc_range[4] - pc_range[1]

    restored = pts.clone()
    restored[..., 0:1] = pts[..., 0:1] * x_extent + pc_range[0]
    restored[..., 1:2] = pts[..., 1:2] * y_extent + pc_range[1]
    return restored
@
HEADS
.
register_module
()
class
MapTRHead
(
DETRHead
):
"""Head of Detr3D.
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
bev_h, bev_w (int): spatial shape of BEV queries.
"""
def __init__(self,
             *args,
             with_box_refine=False,
             as_two_stage=False,
             transformer=None,
             bbox_coder=None,
             num_cls_fcs=2,
             code_weights=None,
             bev_h=30,
             bev_w=30,
             num_vec=20,
             num_pts_per_vec=2,
             num_pts_per_gt_vec=2,
             query_embed_type='all_pts',
             transform_method='minmax',
             gt_shift_pts_pattern='v0',
             dir_interval=1,
             loss_pts=dict(type='ChamferDistance',
                           loss_src_weight=1.0,
                           loss_dst_weight=1.0),
             loss_dir=dict(type='PtsDirCosLoss', loss_weight=2.0),
             **kwargs):
    """Store the head configuration, build the parent head, then build layers.

    Args:
        *args: Forwarded to the parent constructor.
        with_box_refine (bool): Stored; selects cloned vs. shared
            prediction branches in ``_init_layers`` and whether the
            regression branches are handed to the transformer in ``forward``.
        as_two_stage (bool): Stored; when true it is also written into the
            ``transformer`` config under ``'as_two_stage'``.
        transformer: Transformer config. Although the default is ``None``,
            ``transformer.encoder.type`` is read unconditionally, so a
            config must be supplied.
        bbox_coder: Config passed to ``build_bbox_coder``; the built coder
            must expose ``pc_range``.
        num_cls_fcs (int): Stored as ``num_cls_fcs - 1``.
        code_weights (list[float] | None): Per-channel weights; a default
            10-element list is used when ``None``.
        bev_h, bev_w (int): Spatial shape of the BEV queries.
        num_vec (int): Number of vectors (instances) predicted.
        num_pts_per_vec (int): Points per predicted vector.
        num_pts_per_gt_vec (int): Points per ground-truth vector.
        query_embed_type (str): ``'all_pts'`` or ``'instance_pts'``.
        transform_method (str): Points-to-box method; ``transform_box``
            only implements ``'minmax'``.
        gt_shift_pts_pattern (str): ``'v0'`` .. ``'v4'``; selects which
            shifted ground-truth point set ``loss`` reads.
        dir_interval (int): Point stride used for the direction loss.
        loss_pts (dict): Config passed to ``build_loss`` for the point loss.
        loss_dir (dict): Config passed to ``build_loss`` for the direction
            loss.
        **kwargs: Forwarded to the parent constructor; ``code_size`` is
            additionally read from here (default 10 when absent).
    """

    # These attributes are assigned BEFORE the parent constructor runs.
    # NOTE(review): presumably the parent constructor calls
    # ``self._init_layers()``, which reads them - confirm against DETRHead.
    self.bev_h = bev_h
    self.bev_w = bev_w
    self.fp16_enabled = False

    self.with_box_refine = with_box_refine
    self.as_two_stage = as_two_stage
    # Fails with AttributeError if ``transformer`` is left at its None default.
    self.bev_encoder_type = transformer.encoder.type
    if self.as_two_stage:
        transformer['as_two_stage'] = self.as_two_stage
    # ``code_size`` is read from kwargs but not removed, so it is still
    # forwarded to the parent constructor below.
    if 'code_size' in kwargs:
        self.code_size = kwargs['code_size']
    else:
        self.code_size = 10
    if code_weights is not None:
        self.code_weights = code_weights
    else:
        self.code_weights = [1.0, 1.0, 1.0,
                             1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]

    self.bbox_coder = build_bbox_coder(bbox_coder)
    self.pc_range = self.bbox_coder.pc_range
    # Extent of the perception range along x (indices 3, 0) and y (indices 4, 1).
    self.real_w = self.pc_range[3] - self.pc_range[0]
    self.real_h = self.pc_range[4] - self.pc_range[1]
    self.num_cls_fcs = num_cls_fcs - 1

    self.query_embed_type = query_embed_type
    self.transform_method = transform_method
    self.gt_shift_pts_pattern = gt_shift_pts_pattern
    # One query per point of every vector.
    num_query = num_vec * num_pts_per_vec
    self.num_query = num_query
    self.num_vec = num_vec
    self.num_pts_per_vec = num_pts_per_vec
    self.num_pts_per_gt_vec = num_pts_per_gt_vec
    self.dir_interval = dir_interval

    super(MapTRHead, self).__init__(
        *args, transformer=transformer, **kwargs)
    # Replace the plain list with a frozen (non-trainable) parameter.
    self.code_weights = nn.Parameter(torch.tensor(
        self.code_weights, requires_grad=False), requires_grad=False)
    self.loss_pts = build_loss(loss_pts)
    self.loss_dir = build_loss(loss_dir)
    # The query/vector counts are assigned a second time after the parent
    # constructor, and the layers are rebuilt afterwards.
    # NOTE(review): presumably the parent overwrites ``self.num_query`` with
    # its own value, which is why this repetition must stay - confirm before
    # removing what looks like duplicate code.
    num_query = num_vec * num_pts_per_vec
    self.num_query = num_query
    self.num_vec = num_vec
    self.num_pts_per_vec = num_pts_per_vec
    self.num_pts_per_gt_vec = num_pts_per_gt_vec
    self._init_layers()
def _init_layers(self):
    """Build the classification/regression branches and the query embeddings.

    Creates ``self.cls_branches`` and ``self.reg_branches`` with one entry
    per prediction stage. With ``with_box_refine`` every stage gets its own
    deep copy; otherwise all stages share the same module instance. When
    ``as_two_stage`` is false, the BEV embedding and the query embeddings
    selected by ``query_embed_type`` are created as well.
    """
    dims = self.embed_dims

    # Classification MLP: num_reg_fcs x (Linear -> LayerNorm -> ReLU),
    # followed by the projection to class logits.
    cls_layers = []
    for _ in range(self.num_reg_fcs):
        cls_layers.extend([
            Linear(dims, dims),
            nn.LayerNorm(dims),
            nn.ReLU(inplace=True),
        ])
    cls_layers.append(Linear(dims, self.cls_out_channels))
    fc_cls = nn.Sequential(*cls_layers)

    # Regression MLP: num_reg_fcs x (Linear -> ReLU), followed by the
    # projection to ``code_size`` outputs.
    reg_layers = []
    for _ in range(self.num_reg_fcs):
        reg_layers.extend([Linear(dims, dims), nn.ReLU()])
    reg_layers.append(Linear(dims, self.code_size))
    reg_branch = nn.Sequential(*reg_layers)

    # last reg_branch is used to generate proposal from
    # encode feature map when as_two_stage is True.
    num_pred = self.transformer.decoder.num_layers
    if self.as_two_stage:
        num_pred += 1

    if self.with_box_refine:
        # Independent weights per stage.
        self.cls_branches = nn.ModuleList(
            [copy.deepcopy(fc_cls) for _ in range(num_pred)])
        self.reg_branches = nn.ModuleList(
            [copy.deepcopy(reg_branch) for _ in range(num_pred)])
    else:
        # Every stage refers to the same module object (shared weights).
        self.cls_branches = nn.ModuleList([fc_cls] * num_pred)
        self.reg_branches = nn.ModuleList([reg_branch] * num_pred)

    if self.as_two_stage:
        return

    if self.bev_encoder_type == 'BEVFormerEncoder':
        self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, dims)
    else:
        self.bev_embedding = None

    if self.query_embed_type == 'all_pts':
        # One embedding per point query.
        self.query_embedding = nn.Embedding(self.num_query, dims * 2)
    elif self.query_embed_type == 'instance_pts':
        # Separate instance-level and point-level embeddings; ``forward``
        # adds them to form the per-point queries.
        self.query_embedding = None
        self.instance_embedding = nn.Embedding(self.num_vec, dims * 2)
        self.pts_embedding = nn.Embedding(self.num_pts_per_vec, dims * 2)
def init_weights(self):
    """Initialize the transformer and, for sigmoid losses, the class-logit bias.

    When ``self.loss_cls.use_sigmoid`` is true, the bias of the last layer
    of every classification branch is set to the constant returned by
    ``bias_init_with_prob(0.01)``.
    """
    self.transformer.init_weights()
    if not self.loss_cls.use_sigmoid:
        return

    prior_bias = bias_init_with_prob(0.01)
    for branch in self.cls_branches:
        nn.init.constant_(branch[-1].bias, prior_bias)
# for m in self.reg_branches:
# constant_init(m[-1], 0, bias=0)
# nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], 0.)
# @auto_fp16(apply_to=('mlvl_feats'))
@force_fp32(apply_to=('mlvl_feats', 'prev_bev'))
def forward(self, mlvl_feats, lidar_feat, img_metas, prev_bev=None,  only_bev=False):
    """Run the transformer and decode class, box and point predictions per layer.

    Args:
        mlvl_feats (tuple[Tensor]): Features from the upstream network.
            The first level is unpacked as a 5-D tensor (documented by the
            original author as (B, N, C, H, W)); its dtype is used for all
            embeddings created here.
        lidar_feat: Forwarded unchanged to the transformer.
        img_metas: Forwarded unchanged to the transformer.
        prev_bev: Previous BEV features, forwarded to the transformer.
        only_bev (bool): If True, return the result of
            ``self.transformer.get_bev_features`` and skip decoding.

    Returns:
        If ``only_bev`` is true: whatever
        ``self.transformer.get_bev_features`` returns. Otherwise a dict:

        - ``bev_embed``: first element of the transformer output.
        - ``all_cls_scores`` (Tensor): stacked over decoder levels, each
          level is [bs, num_vec, cls_out_channels] (point features are
          averaged per vector before classification).
        - ``all_bbox_preds`` (Tensor): stacked boxes produced by
          ``transform_box`` from the sigmoid point predictions.
        - ``all_pts_preds`` (Tensor): stacked points, each level is
          [bs, num_vec, num_pts_per_vec, 2].
        - ``enc_cls_scores``, ``enc_bbox_preds``, ``enc_pts_preds``:
          always ``None`` here.
    """
    # ``num_cam`` is unpacked but not used below.
    bs, num_cam, _, _, _ = mlvl_feats[0].shape
    dtype = mlvl_feats[0].dtype
    # NOTE(review): if ``query_embed_type`` is neither of the two values
    # below, ``object_query_embeds`` is never bound and the decoding path
    # raises UnboundLocalError.
    if self.query_embed_type == 'all_pts':
        object_query_embeds = self.query_embedding.weight.to(dtype)
    elif self.query_embed_type == 'instance_pts':
        # Broadcast-add point embeddings (1, P, C) and instance embeddings
        # (V, 1, C), then flatten to one query per point: (V * P, C).
        pts_embeds = self.pts_embedding.weight.unsqueeze(0)
        instance_embeds = self.instance_embedding.weight.unsqueeze(1)
        object_query_embeds = (pts_embeds + instance_embeds).flatten(0, 1).to(dtype)
    if self.bev_embedding is not None:
        bev_queries = self.bev_embedding.weight.to(dtype)

        # All-zero mask fed to the positional encoding to get one position
        # embedding per BEV cell.
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)
    else:
        bev_queries = None
        bev_mask = None
        bev_pos = None

    if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
        return self.transformer.get_bev_features(
            mlvl_feats,
            lidar_feat,
            bev_queries,
            self.bev_h,
            self.bev_w,
            grid_length=(self.real_h / self.bev_h,
                         self.real_w / self.bev_w),
            bev_pos=bev_pos,
            img_metas=img_metas,
            prev_bev=prev_bev,
        )
    else:
        outputs = self.transformer(
            mlvl_feats,
            lidar_feat,
            bev_queries,
            object_query_embeds,
            self.bev_h,
            self.bev_w,
            grid_length=(self.real_h / self.bev_h,
                         self.real_w / self.bev_w),
            bev_pos=bev_pos,
            reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
            cls_branches=self.cls_branches if self.as_two_stage else None,
            img_metas=img_metas,
            prev_bev=prev_bev
        )

    bev_embed, hs, init_reference, inter_references = outputs
    # Swap dims 1 and 2 of the decoder states.
    # NOTE(review): presumably (levels, queries, bs, C) -> (levels, bs,
    # queries, C) - confirm against the transformer's output layout.
    hs = hs.permute(0, 2, 1, 3)
    outputs_classes = []
    outputs_coords = []
    outputs_pts_coords = []
    for lvl in range(hs.shape[0]):
        # Level 0 refines the initial reference; later levels refine the
        # reference produced by the previous level.
        if lvl == 0:
            reference = init_reference
        else:
            reference = inter_references[lvl - 1]
        reference = inverse_sigmoid(reference)
        # Classify per vector: average the point features of each vector
        # before applying the classification branch.
        outputs_class = self.cls_branches[lvl](hs[lvl]
                                        .view(bs,self.num_vec, self.num_pts_per_vec,-1)
                                        .mean(2))
        tmp = self.reg_branches[lvl](hs[lvl])

        # TODO: check the shape of reference
        assert reference.shape[-1] == 2
        # Offsets are added in logit space, then squashed back to (0, 1).
        tmp[..., 0:2] += reference[..., 0:2]
        tmp = tmp.sigmoid()
        # TODO: check if using sigmoid
        outputs_coord, outputs_pts_coord = self.transform_box(tmp)
        outputs_classes.append(outputs_class)
        outputs_coords.append(outputs_coord)
        outputs_pts_coords.append(outputs_pts_coord)
    outputs_classes = torch.stack(outputs_classes)
    outputs_coords = torch.stack(outputs_coords)
    outputs_pts_coords = torch.stack(outputs_pts_coords)
    outs = {
        'bev_embed': bev_embed,
        'all_cls_scores': outputs_classes,
        'all_bbox_preds': outputs_coords,
        'all_pts_preds': outputs_pts_coords,
        'enc_cls_scores': None,
        'enc_bbox_preds': None,
        'enc_pts_preds': None
    }

    return outs
def transform_box(self, pts, y_first=False):
    """Reshape flat point predictions per vector and derive a bounding box.

    Args:
        pts (Tensor): Point predictions; viewed as
            ``(pts.shape[0], num_vec, num_pts_per_vec, 2)``.
        y_first (bool): If True the last dimension is read as ``(y, x)``,
            otherwise as ``(x, y)``.

    Returns:
        tuple: ``(bbox, pts_reshape)`` where ``bbox`` is the cxcywh form of
        the per-vector min/max box and ``pts_reshape`` is the reshaped
        input.

    Raises:
        NotImplementedError: If ``self.transform_method`` is not
            ``'minmax'``.
    """
    pts_reshape = pts.view(pts.shape[0], self.num_vec,
                           self.num_pts_per_vec, 2)
    if y_first:
        y_channel, x_channel = 0, 1
    else:
        y_channel, x_channel = 1, 0
    pts_y = pts_reshape[..., y_channel]
    pts_x = pts_reshape[..., x_channel]

    if self.transform_method != 'minmax':
        raise NotImplementedError

    # Reduce over the points of each vector (dim 2), keeping that dim so
    # the four extrema can be concatenated into an xyxy box.
    xmin, _ = pts_x.min(dim=2, keepdim=True)
    xmax, _ = pts_x.max(dim=2, keepdim=True)
    ymin, _ = pts_y.min(dim=2, keepdim=True)
    ymax, _ = pts_y.max(dim=2, keepdim=True)
    xyxy = torch.cat([xmin, ymin, xmax, ymax], dim=2)
    bbox = bbox_xyxy_to_cxcywh(xyxy)
    return bbox, pts_reshape
def _get_target_single(self,
                       cls_score,
                       bbox_pred,
                       pts_pred,
                       gt_labels,
                       gt_bboxes,
                       gt_shifts_pts,
                       gt_bboxes_ignore=None):
    """Compute classification, box and point targets for one sample.

    Predictions of a single decoder layer are matched to the ground truth
    with ``self.assigner`` and sampled with ``self.sampler``.

    Args:
        cls_score (Tensor): Class logits for one sample.
        bbox_pred (Tensor): Box predictions for one sample; its first
            dimension is the number of predictions.
        pts_pred (Tensor): Point predictions for one sample; a 3-D tensor
            (the target is allocated from its first three sizes).
        gt_labels (Tensor): Ground-truth class indices.
        gt_bboxes (Tensor): Ground-truth boxes; the size of the last
            dimension determines the width of ``bbox_targets``.
        gt_shifts_pts (Tensor): Ground-truth point sets, indexed as
            ``[gt_index, shift_index, :, :]``.
        gt_bboxes_ignore (Tensor, optional): Forwarded to the assigner.
            Default None.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
        pts_targets, pts_weights, pos_inds, neg_inds)``.
    """
    num_bboxes = bbox_pred.size(0)
    # assigner and sampler
    gt_c = gt_bboxes.shape[-1]
    # The assigner returns the assignment plus ``order_index``, which is
    # used below to pick the matched shift for every (prediction, gt) pair.
    assign_result, order_index = self.assigner.assign(bbox_pred, cls_score, pts_pred,
                                         gt_bboxes, gt_labels, gt_shifts_pts,
                                         gt_bboxes_ignore)

    sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                          gt_bboxes)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds

    # label targets: every prediction starts as ``num_classes``
    # (the background index), positives get their matched gt label.
    labels = gt_bboxes.new_full((num_bboxes,),
                                self.num_classes,
                                dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_bboxes.new_ones(num_bboxes)

    # bbox targets: zeros truncated to the gt box width; only positives
    # get a non-zero weight.
    bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_c]
    bbox_weights = torch.zeros_like(bbox_pred)
    bbox_weights[pos_inds] = 1.0

    # pts targets
    # num_query, num_order, num_points, num_coords
    if order_index is None:
        # NOTE(review): this fallback uses the matched gt *labels* as the
        # shift index, which looks unintended - confirm with the assigner
        # whether ``order_index`` can actually be None.
        assigned_shift = gt_labels[sampling_result.pos_assigned_gt_inds]
    else:
        assigned_shift = order_index[sampling_result.pos_inds, sampling_result.pos_assigned_gt_inds]
    pts_targets = pts_pred.new_zeros((pts_pred.size(0),
                    pts_pred.size(1), pts_pred.size(2)))
    pts_weights = torch.zeros_like(pts_targets)
    pts_weights[pos_inds] = 1.0

    # DETR
    bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
    # For every positive, take the matched gt's point set in its matched shift.
    pts_targets[pos_inds] = gt_shifts_pts[sampling_result.pos_assigned_gt_inds,assigned_shift,:,:]
    return (labels, label_weights, bbox_targets, bbox_weights,
            pts_targets, pts_weights,
            pos_inds, neg_inds)
def get_targets(self,
                cls_scores_list,
                bbox_preds_list,
                pts_preds_list,
                gt_bboxes_list,
                gt_labels_list,
                gt_shifts_pts_list,
                gt_bboxes_ignore_list=None):
    """Compute targets for every sample of a batch (one decoder layer).

    Applies ``_get_target_single`` to each sample through ``multi_apply``
    and counts the positive and negative predictions over the batch.

    Args:
        cls_scores_list (list[Tensor]): Class logits per sample.
        bbox_preds_list (list[Tensor]): Box predictions per sample.
        pts_preds_list (list[Tensor]): Point predictions per sample.
        gt_bboxes_list (list[Tensor]): Ground-truth boxes per sample.
        gt_labels_list (list[Tensor]): Ground-truth labels per sample.
        gt_shifts_pts_list (list[Tensor]): Shifted ground-truth point sets
            per sample.
        gt_bboxes_ignore_list: Must be None.

    Returns:
        tuple: ``(labels_list, label_weights_list, bbox_targets_list,
        bbox_weights_list, pts_targets_list, pts_weights_list,
        num_total_pos, num_total_neg)``.
    """
    assert gt_bboxes_ignore_list is None, \
        'Only supports for gt_bboxes_ignore setting to None.'

    # One (None) entry per sample so it can be passed alongside the other
    # per-sample lists.
    ignore_per_sample = [gt_bboxes_ignore_list] * len(cls_scores_list)

    per_sample_targets = multi_apply(
        self._get_target_single,
        cls_scores_list,
        bbox_preds_list,
        pts_preds_list,
        gt_labels_list,
        gt_bboxes_list,
        gt_shifts_pts_list,
        ignore_per_sample)
    (labels_list, label_weights_list, bbox_targets_list,
     bbox_weights_list, pts_targets_list, pts_weights_list,
     pos_inds_list, neg_inds_list) = per_sample_targets

    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, pts_targets_list, pts_weights_list,
            num_total_pos, num_total_neg)
def loss_single(self,
                cls_scores,
                bbox_preds,
                pts_preds,
                gt_bboxes_list,
                gt_labels_list,
                gt_shifts_pts_list,
                gt_bboxes_ignore_list=None):
    """Compute the five loss terms for the outputs of one decoder layer.

    Args:
        cls_scores (Tensor): Class logits for all samples; the first
            dimension is the batch.
        bbox_preds (Tensor): Box predictions for all samples.
        pts_preds (Tensor): Point predictions for all samples.
        gt_bboxes_list (list[Tensor]): Ground-truth boxes per sample.
        gt_labels_list (list[Tensor]): Ground-truth labels per sample.
        gt_shifts_pts_list (list[Tensor]): Shifted ground-truth point sets
            per sample.
        gt_bboxes_ignore_list (list[Tensor], optional): Forwarded to
            ``get_targets``. Default None.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_bbox, loss_iou, loss_pts,
        loss_dir)``.
    """
    num_imgs = cls_scores.size(0)
    # Split the batch into per-sample lists for target assignment.
    cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
    bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
    pts_preds_list = [pts_preds[i] for i in range(num_imgs)]
    cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,pts_preds_list,
                                       gt_bboxes_list, gt_labels_list,gt_shifts_pts_list,
                                       gt_bboxes_ignore_list)
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     pts_targets_list, pts_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    # Flatten the per-sample targets back into batch-major tensors.
    labels = torch.cat(labels_list, 0)
    label_weights = torch.cat(label_weights_list, 0)
    bbox_targets = torch.cat(bbox_targets_list, 0)
    bbox_weights = torch.cat(bbox_weights_list, 0)
    pts_targets = torch.cat(pts_targets_list, 0)
    pts_weights = torch.cat(pts_weights_list, 0)

    # classification loss
    cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
    # construct weighted avg_factor to match with the official DETR repo
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        cls_avg_factor = reduce_mean(
            cls_scores.new_tensor([cls_avg_factor]))

    # Guard against a zero normalizer when there are no positives.
    cls_avg_factor = max(cls_avg_factor, 1)
    loss_cls = self.loss_cls(
        cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

    # Compute the average number of gt boxes across all gpus, for
    # normalization purposes
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

    # regression L1 loss
    bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
    normalized_bbox_targets = normalize_2d_bbox(bbox_targets, self.pc_range)
    # Rows whose normalized box target is entirely finite. The same row
    # mask is reused for the box, point, direction and IoU terms below.
    isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
    bbox_weights = bbox_weights * self.code_weights

    loss_bbox = self.loss_bbox(
        bbox_preds[isnotnan, :4], normalized_bbox_targets[isnotnan,
                                                          :4], bbox_weights[isnotnan, :4],
        avg_factor=num_total_pos)

    # regression pts CD loss
    # num_samples, num_order, num_pts, num_coords
    normalized_pts_targets = normalize_2d_pts(pts_targets, self.pc_range)

    # num_samples, num_pts, num_coords
    pts_preds = pts_preds.reshape(-1, pts_preds.size(-2),pts_preds.size(-1))
    if self.num_pts_per_vec != self.num_pts_per_gt_vec:
        # Resample the predicted polyline to the ground-truth point count:
        # interpolate along the point axis, which must be last for
        # F.interpolate, hence the permutes.
        pts_preds = pts_preds.permute(0,2,1)
        pts_preds = F.interpolate(pts_preds, size=(self.num_pts_per_gt_vec), mode='linear',
                                align_corners=True)
        pts_preds = pts_preds.permute(0,2,1).contiguous()

    loss_pts = self.loss_pts(
        pts_preds[isnotnan,:,:], normalized_pts_targets[isnotnan,
                                                        :,:],
        pts_weights[isnotnan,:,:],
        avg_factor=num_total_pos)
    # Direction loss: compare displacement vectors between points that are
    # ``dir_interval`` apart. Weights take the first coordinate channel and
    # drop the last ``dir_interval`` points so they align with the
    # displacement tensors.
    dir_weights = pts_weights[:, :-self.dir_interval,0]
    # Predictions are denormalized so their displacements are compared with
    # the targets, which are used here without normalization.
    denormed_pts_preds = denormalize_2d_pts(pts_preds, self.pc_range)
    denormed_pts_preds_dir = denormed_pts_preds[:,self.dir_interval:,:] - denormed_pts_preds[:,:-self.dir_interval,:]
    pts_targets_dir = pts_targets[:, self.dir_interval:,:] - pts_targets[:,:-self.dir_interval,:]
    loss_dir = self.loss_dir(
        denormed_pts_preds_dir[isnotnan,:,:], pts_targets_dir[isnotnan,
                                                                      :,:],
        dir_weights[isnotnan,:],
        avg_factor=num_total_pos)

    bboxes = denormalize_2d_bbox(bbox_preds, self.pc_range)
    # regression IoU loss, by default GIoU loss
    loss_iou = self.loss_iou(
        bboxes[isnotnan, :4], bbox_targets[isnotnan, :4], bbox_weights[isnotnan, :4],
        avg_factor=num_total_pos)

    # ``torch.nan_to_num`` is only used on torch >= 1.8.
    if digit_version(TORCH_VERSION) >= digit_version('1.8'):
        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)
        loss_iou = torch.nan_to_num(loss_iou)
        loss_pts = torch.nan_to_num(loss_pts)
        loss_dir = torch.nan_to_num(loss_dir)
    return loss_cls, loss_bbox, loss_iou, loss_pts, loss_dir
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
         gt_bboxes_list,
         gt_labels_list,
         preds_dicts,
         gt_bboxes_ignore=None,
         img_metas=None):
    """Compute the losses of every decoder layer and collect them in a dict.

    Args:
        gt_bboxes_list (list): Ground-truth vector containers, one per
            sample. Each must expose ``bbox``,
            ``fixed_num_sampled_points`` and the
            ``shift_fixed_num_sampled_points*`` attribute selected by
            ``self.gt_shift_pts_pattern``. The list is deep-copied before
            use.
        gt_labels_list (list[Tensor]): Ground-truth class indices, one
            tensor per sample.
        preds_dicts (dict): Output of ``forward`` with the keys
            ``all_cls_scores``, ``all_bbox_preds``, ``all_pts_preds``,
            ``enc_cls_scores``, ``enc_bbox_preds`` and ``enc_pts_preds``.
        gt_bboxes_ignore: Must be None.
        img_metas: Unused.

    Returns:
        dict[str, Tensor]: ``loss_cls``, ``loss_bbox``, ``loss_iou``,
        ``loss_pts`` and ``loss_dir`` from the last decoder layer,
        ``d{i}.<name>`` for every earlier layer ``i``, and the ``enc_*``
        entries when encoder predictions are present.

    Raises:
        NotImplementedError: If ``self.gt_shift_pts_pattern`` is not one
            of ``'v0'`` .. ``'v4'``.
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'

    gt_vecs_list = copy.deepcopy(gt_bboxes_list)
    all_cls_scores = preds_dicts['all_cls_scores']
    all_bbox_preds = preds_dicts['all_bbox_preds']
    all_pts_preds = preds_dicts['all_pts_preds']
    enc_cls_scores = preds_dicts['enc_cls_scores']
    enc_bbox_preds = preds_dicts['enc_bbox_preds']
    enc_pts_preds = preds_dicts['enc_pts_preds']

    num_dec_layers = len(all_cls_scores)
    device = gt_labels_list[0].device

    gt_bboxes_list = [
        gt_vecs.bbox.to(device) for gt_vecs in gt_vecs_list]
    gt_pts_list = [
        gt_vecs.fixed_num_sampled_points.to(device)
        for gt_vecs in gt_vecs_list]

    # Attribute holding the shifted ground-truth point sets per pattern.
    shift_attr_by_pattern = {
        'v0': 'shift_fixed_num_sampled_points',
        'v1': 'shift_fixed_num_sampled_points_v1',
        'v2': 'shift_fixed_num_sampled_points_v2',
        'v3': 'shift_fixed_num_sampled_points_v3',
        'v4': 'shift_fixed_num_sampled_points_v4',
    }
    if self.gt_shift_pts_pattern not in shift_attr_by_pattern:
        raise NotImplementedError
    shift_attr = shift_attr_by_pattern[self.gt_shift_pts_pattern]
    gt_shifts_pts_list = [
        getattr(gt_vecs, shift_attr).to(device)
        for gt_vecs in gt_vecs_list]

    # Every decoder layer is supervised with the same ground truth.
    all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
    all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
    all_gt_shifts_pts_list = [
        gt_shifts_pts_list for _ in range(num_dec_layers)]
    all_gt_bboxes_ignore_list = [
        gt_bboxes_ignore for _ in range(num_dec_layers)]

    losses_cls, losses_bbox, losses_iou, losses_pts, losses_dir = multi_apply(
        self.loss_single, all_cls_scores, all_bbox_preds, all_pts_preds,
        all_gt_bboxes_list, all_gt_labels_list, all_gt_shifts_pts_list,
        all_gt_bboxes_ignore_list)

    loss_dict = dict()
    # loss of proposal generated from encode feature map.
    if enc_cls_scores is not None:
        # One all-zero label tensor per SAMPLE. The previous code iterated
        # over the number of decoder layers while indexing the per-sample
        # label list, which produced the wrong number of entries (or an
        # IndexError) whenever the two counts differed.
        binary_labels_list = [
            torch.zeros_like(gt_labels) for gt_labels in gt_labels_list]
        # NOTE(review): ``gt_pts_list`` is passed where ``loss_single``
        # expects the shifted point sets; this looks like it should be
        # ``gt_shifts_pts_list`` - confirm before enabling two-stage mode.
        enc_loss_cls, enc_losses_bbox, enc_losses_iou, enc_losses_pts, enc_losses_dir = \
            self.loss_single(enc_cls_scores, enc_bbox_preds, enc_pts_preds,
                             gt_bboxes_list, binary_labels_list, gt_pts_list,
                             gt_bboxes_ignore)
        loss_dict['enc_loss_cls'] = enc_loss_cls
        loss_dict['enc_loss_bbox'] = enc_losses_bbox
        loss_dict['enc_losses_iou'] = enc_losses_iou
        loss_dict['enc_losses_pts'] = enc_losses_pts
        loss_dict['enc_losses_dir'] = enc_losses_dir

    # loss from the last decoder layer
    loss_dict['loss_cls'] = losses_cls[-1]
    loss_dict['loss_bbox'] = losses_bbox[-1]
    loss_dict['loss_iou'] = losses_iou[-1]
    loss_dict['loss_pts'] = losses_pts[-1]
    loss_dict['loss_dir'] = losses_dir[-1]

    # loss from other decoder layers
    earlier_layers = zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1],
                         losses_pts[:-1], losses_dir[:-1])
    for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i,
                        loss_pts_i, loss_dir_i) in enumerate(earlier_layers):
        loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
        loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
        loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
        loss_dict[f'd{num_dec_layer}.loss_pts'] = loss_pts_i
        loss_dict[f'd{num_dec_layer}.loss_dir'] = loss_dir_i
    return loss_dict
@force_fp32(apply_to=('preds_dicts',))
def get_bboxes(self, preds_dicts, img_metas, rescale=False):
    """Decode raw head predictions into one result list per sample.

    Args:
        preds_dicts: Raw head outputs; passed unchanged to
            ``self.bbox_coder.decode``.
        img_metas: Not used by this method; kept for interface
            compatibility.
        rescale (bool): Not used by this method; kept for interface
            compatibility.

    Returns:
        list[list]: One ``[bboxes, scores, labels, pts]`` entry per decoded
        sample, read from the ``'bboxes'``, ``'scores'``, ``'labels'`` and
        ``'pts'`` keys of each decoded dict.
    """
    # ``apply_to`` must be a real tuple: with a bare string the decorator's
    # membership test degenerates into substring matching on argument names.
    # NOTE(review): boxes are presumably (xmin, ymin, xmax, ymax) -- confirm
    # against the configured bbox coder.
    decoded_samples = self.bbox_coder.decode(preds_dicts)
    return [
        [sample['bboxes'], sample['scores'], sample['labels'], sample['pts']]
        for sample in decoded_samples
    ]
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/dense_heads/maptrv2_head.py
0 → 100644
View file @
19472568
import
copy
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models
import
HEADS
,
build_loss
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet3d.core.bbox.coders
import
build_bbox_coder
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
xavier_init
,
constant_init
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.core.bbox.transforms
import
bbox_xyxy_to_cxcywh
,
bbox_cxcywh_to_xyxy
from
mmdet.core
import
(
multi_apply
,
multi_apply
,
reduce_mean
)
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
def denormalize_3d_pts(pts, pc_range):
    """Map the first three channels of ``pts`` from unit range to ``pc_range``.

    Channel ``k`` (k = 0, 1, 2) becomes
    ``pts[..., k] * (pc_range[k + 3] - pc_range[k]) + pc_range[k]``.
    Any further channels are copied through unchanged. The input tensor is
    not modified; a clone is returned.
    """
    restored = pts.clone()
    for axis in range(3):
        lower = pc_range[axis]
        extent = pc_range[axis + 3] - lower
        restored[..., axis:axis + 1] = pts[..., axis:axis + 1] * extent + lower
    return restored
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_3d_pts(pts, pc_range):
    """Normalize 3-channel points by the extents described in ``pc_range``.

    Each of the first three channels is shifted by ``pc_range[k]`` and the
    result is divided by ``[x_extent, y_extent, z_extent]`` where the
    extents are ``pc_range[k + 3] - pc_range[k]``. The division broadcasts a
    3-element factor over the last dimension. The input is not modified.
    """
    extents = [pc_range[axis + 3] - pc_range[axis] for axis in range(3)]
    shifted = pts.clone()
    for axis in range(3):
        shifted[..., axis:axis + 1] = pts[..., axis:axis + 1] - pc_range[axis]
    scale = pts.new_tensor(extents)
    return shifted / scale
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_2d_bbox(bboxes, pc_range):
    """Convert corner-format boxes to center format scaled by ``pc_range``.

    The boxes are converted with ``bbox_xyxy_to_cxcywh``; the first two
    channels are then shifted by ``pc_range[0]`` / ``pc_range[1]`` and all
    four channels are divided by ``[x_extent, y_extent, x_extent, y_extent]``.
    """
    x_origin = pc_range[0]
    y_origin = pc_range[1]
    x_extent = pc_range[3] - x_origin
    y_extent = pc_range[4] - y_origin

    centered = bbox_xyxy_to_cxcywh(bboxes)
    centered[..., 0:1] = centered[..., 0:1] - x_origin
    centered[..., 1:2] = centered[..., 1:2] - y_origin

    scale = bboxes.new_tensor([x_extent, y_extent, x_extent, y_extent])
    return centered / scale
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_2d_pts(pts, pc_range):
    """Normalize 2-channel points by the x/y extents in ``pc_range``.

    Channel 0 is shifted by ``pc_range[0]`` and channel 1 by
    ``pc_range[1]``; the result is divided by ``[x_extent, y_extent]``
    (``pc_range[3] - pc_range[0]`` and ``pc_range[4] - pc_range[1]``),
    broadcast over the last dimension. The input is not modified.
    """
    extents = [pc_range[3] - pc_range[0], pc_range[4] - pc_range[1]]
    shifted = pts.clone()
    for axis in range(2):
        shifted[..., axis:axis + 1] = pts[..., axis:axis + 1] - pc_range[axis]
    scale = pts.new_tensor(extents)
    return shifted / scale
#@torch.compile(mode="max-autotune-no-cudagraphs")
def denormalize_2d_bbox(bboxes, pc_range):
    """Convert center-format boxes to corner format rescaled to ``pc_range``.

    After ``bbox_cxcywh_to_xyxy``, the even-indexed channels are scaled by
    the x extent and offset by ``pc_range[0]``; the odd-indexed channels are
    scaled by the y extent and offset by ``pc_range[1]``.
    """
    corners = bbox_cxcywh_to_xyxy(bboxes)
    axis_bounds = (
        (0, pc_range[0], pc_range[3]),  # x channels: 0, 2, ...
        (1, pc_range[1], pc_range[4]),  # y channels: 1, 3, ...
    )
    for start, lower, upper in axis_bounds:
        corners[..., start::2] = corners[..., start::2] * (upper - lower) + lower
    return corners
#@torch.compile(mode="max-autotune-no-cudagraphs")
def denormalize_2d_pts(pts, pc_range):
    """Map the first two channels of ``pts`` from unit range to ``pc_range``.

    Channel 0 becomes ``x * (pc_range[3] - pc_range[0]) + pc_range[0]`` and
    channel 1 becomes ``y * (pc_range[4] - pc_range[1]) + pc_range[1]``.
    Any further channels are copied through unchanged; the input is not
    modified.
    """
    restored = pts.clone()
    for axis in range(2):
        lower = pc_range[axis]
        extent = pc_range[axis + 3] - lower
        restored[..., axis:axis + 1] = pts[..., axis:axis + 1] * extent + lower
    return restored
@
HEADS
.
register_module
()
class
MapTRv2Head
(
DETRHead
):
"""Head of Detr3D.
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
bev_h, bev_w (int): spatial shape of BEV queries.
"""
def __init__(self,
             *args,
             with_box_refine=False,
             as_two_stage=False,
             transformer=None,
             bbox_coder=None,
             num_cls_fcs=2,
             code_weights=None,
             bev_h=30,
             bev_w=30,
             # num_vec=20,
             num_vec_one2one=50,
             num_vec_one2many=0,
             k_one2many=0,
             lambda_one2many=1,
             num_pts_per_vec=2,
             num_pts_per_gt_vec=2,
             query_embed_type='all_pts',
             transform_method='minmax',
             gt_shift_pts_pattern='v0',
             dir_interval=1,
             aux_seg=dict(
                 use_aux_seg=False,
                 bev_seg=False,
                 pv_seg=False,
                 seg_classes=1,
                 feat_down_sample=32,
             ),
             z_cfg=dict(
                 pred_z_flag=False,
                 gt_z_flag=False,
             ),
             loss_pts=dict(type='ChamferDistance',
                           loss_src_weight=1.0,
                           loss_dst_weight=1.0),
             loss_seg=dict(type='SimpleLoss',
                           pos_weight=2.13,
                           loss_weight=1.0),
             loss_pv_seg=dict(type='SimpleLoss',
                              pos_weight=2.13,
                              loss_weight=1.0),
             loss_dir=dict(type='PtsDirCosLoss', loss_weight=2.0),
             **kwargs):
    """Store the head configuration, build losses and create the layers.

    Plain attributes are stored before the parent constructor runs, then
    ``num_query`` and the vector counts are stored again afterwards, and
    finally ``self._init_layers()`` is called.
    NOTE(review): the second round of assignments presumably exists because
    the parent constructor overwrites ``num_query`` -- confirm against
    ``DETRHead.__init__``.

    Args:
        with_box_refine (bool): Stored; selects cloned vs shared branches
            in ``_init_layers``.
        as_two_stage (bool): Stored; when true it is also written into
            ``transformer['as_two_stage']``.
        transformer: Config whose ``encoder.type`` is read here; forwarded
            to the parent constructor.
        bbox_coder: Config given to ``build_bbox_coder``; the resulting
            coder's ``pc_range`` is stored as ``self.pc_range``.
        num_cls_fcs (int): Stored as ``num_cls_fcs - 1``.
        code_weights (list | None): Defaults to eight ``1.0`` entries
            followed by two ``0.2`` entries.
        bev_h, bev_w (int): BEV grid size.
        num_vec_one2one, num_vec_one2many (int): Their sum is ``num_vec``.
        k_one2many, lambda_one2many: Stored unchanged.
        num_pts_per_vec, num_pts_per_gt_vec (int): Stored;
            ``num_query = num_vec * num_pts_per_vec``.
        query_embed_type, transform_method, gt_shift_pts_pattern,
        dir_interval, aux_seg, z_cfg: Stored unchanged.
        loss_pts, loss_seg, loss_pv_seg, loss_dir (dict): Configs given to
            ``build_loss``.
    """
    self.bev_h = bev_h
    self.bev_w = bev_w
    self.fp16_enabled = False

    self.with_box_refine = with_box_refine
    self.as_two_stage = as_two_stage
    self.bev_encoder_type = transformer.encoder.type
    if self.as_two_stage:
        transformer['as_two_stage'] = self.as_two_stage

    # Three regression channels only when 'code_size' was supplied AND z
    # prediction is requested; two otherwise.
    if 'code_size' in kwargs and z_cfg['pred_z_flag']:
        self.code_size = 3
    else:
        self.code_size = 2

    if code_weights is None:
        code_weights = [1.0] * 8 + [0.2, 0.2]
    self.code_weights = code_weights

    self.bbox_coder = build_bbox_coder(bbox_coder)
    self.pc_range = self.bbox_coder.pc_range
    self.real_w = self.pc_range[3] - self.pc_range[0]
    self.real_h = self.pc_range[4] - self.pc_range[1]
    self.num_cls_fcs = num_cls_fcs - 1

    self.query_embed_type = query_embed_type
    self.transform_method = transform_method
    self.gt_shift_pts_pattern = gt_shift_pts_pattern

    num_vec = num_vec_one2one + num_vec_one2many
    num_query = num_vec * num_pts_per_vec
    self.num_query = num_query
    self.num_vec = num_vec
    self.num_pts_per_vec = num_pts_per_vec
    self.num_pts_per_gt_vec = num_pts_per_gt_vec
    self.dir_interval = dir_interval
    self.aux_seg = aux_seg
    self.z_cfg = z_cfg

    super(MapTRv2Head, self).__init__(
        *args, transformer=transformer, **kwargs)

    # Frozen parameter holding the per-channel regression weights.
    weights_tensor = torch.tensor(self.code_weights, requires_grad=False)
    self.code_weights = nn.Parameter(weights_tensor, requires_grad=False)

    self.loss_pts = build_loss(loss_pts)
    self.loss_dir = build_loss(loss_dir)

    # Stored again after the parent constructor has run.
    num_query = num_vec * num_pts_per_vec
    self.num_query = num_query
    self.num_vec = num_vec
    self.num_pts_per_vec = num_pts_per_vec
    self.num_pts_per_gt_vec = num_pts_per_gt_vec

    self.num_vec_one2one = num_vec_one2one
    self.num_vec_one2many = num_vec_one2many
    self.k_one2many = k_one2many
    self.lambda_one2many = lambda_one2many

    self.loss_seg = build_loss(loss_seg)
    self.loss_pv_seg = build_loss(loss_pv_seg)

    self._init_layers()
def _init_layers(self):
    """Build the classification/regression branches and query embeddings.

    Creates ``cls_branches`` and ``reg_branches`` (one entry per decoder
    layer, plus one when ``as_two_stage``), the optional auxiliary
    segmentation heads, and -- when not two-stage -- the BEV and query
    embeddings.

    Raises:
        ValueError: If auxiliary segmentation is enabled but neither
            ``bev_seg`` nor ``pv_seg`` is set.
    """
    dims = self.embed_dims

    cls_layers = []
    for _ in range(self.num_reg_fcs):
        cls_layers += [
            Linear(dims, dims),
            nn.LayerNorm(dims),
            nn.ReLU(inplace=True),
        ]
    cls_layers.append(Linear(dims, self.cls_out_channels))
    fc_cls = nn.Sequential(*cls_layers)

    reg_layers = []
    for _ in range(self.num_reg_fcs):
        reg_layers += [Linear(dims, dims), nn.ReLU()]
    reg_layers.append(Linear(dims, self.code_size))
    reg_branch = nn.Sequential(*reg_layers)

    # The extra (last) branch is used to generate proposals from the
    # encoder feature map when as_two_stage is True.
    num_pred = self.transformer.decoder.num_layers
    if self.as_two_stage:
        num_pred = num_pred + 1

    if self.with_box_refine:
        # Independent copy of each branch per prediction layer.
        self.cls_branches = nn.ModuleList(
            [copy.deepcopy(fc_cls) for _ in range(num_pred)])
        self.reg_branches = nn.ModuleList(
            [copy.deepcopy(reg_branch) for _ in range(num_pred)])
    else:
        # The same module object is shared by every prediction layer.
        self.cls_branches = nn.ModuleList([fc_cls] * num_pred)
        self.reg_branches = nn.ModuleList([reg_branch] * num_pred)

    def _build_seg_head():
        return nn.Sequential(
            nn.Conv2d(dims, dims, kernel_size=3, padding=1, bias=False),
            # nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(dims, self.aux_seg['seg_classes'],
                      kernel_size=1, padding=0),
        )

    if self.aux_seg['use_aux_seg']:
        wants_bev = self.aux_seg['bev_seg']
        if not (wants_bev or self.aux_seg['pv_seg']):
            raise ValueError('aux_seg must have bev_seg or pv_seg')
        if wants_bev:
            self.seg_head = _build_seg_head()
        if self.aux_seg['pv_seg']:
            self.pv_seg_head = _build_seg_head()

    if self.as_two_stage:
        return

    if 'BEVFormerEncoder' in self.bev_encoder_type:
        self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, dims)
    else:
        self.bev_embedding = None

    if self.query_embed_type == 'all_pts':
        self.query_embedding = nn.Embedding(self.num_query, dims * 2)
    elif self.query_embed_type == 'instance_pts':
        self.query_embedding = None
        self.instance_embedding = nn.Embedding(self.num_vec, dims * 2)
        self.pts_embedding = nn.Embedding(self.num_pts_per_vec, dims * 2)
def init_weights(self):
    """Initialize the transformer and the classification output biases.

    The transformer's own ``init_weights`` is always called. When the
    classification loss uses sigmoid, the bias of the last layer of every
    classification branch is set to ``bias_init_with_prob(0.01)``.
    """
    self.transformer.init_weights()
    if not self.loss_cls.use_sigmoid:
        return
    prior_bias = bias_init_with_prob(0.01)
    for branch in self.cls_branches:
        nn.init.constant_(branch[-1].bias, prior_bias)
# for m in self.reg_branches:
# constant_init(m[-1], 0, bias=0)
# nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], 0.)
#@torch.compile(mode="max-autotune-no-cudagraphs") ####
def compute_decoder_predictions(self, outputs, bs, num_vec, mlvl_feats):
    """Turn transformer outputs into per-layer class/box/point predictions.

    Args:
        outputs: 5-tuple ``(bev_embed, depth, hs, init_reference,
            inter_references)`` as returned by the transformer.
        bs (int): Batch size used to reshape the decoder states.
        num_vec (int): Number of vectors (instances) among the queries.
        mlvl_feats: Multi-level image features; only the last level is
            used, and only for perspective-view segmentation.

    Returns:
        tuple: ``(bev_embed, cls_one2one, coords_one2one,
        pts_coords_one2one, depth, outputs_seg, outputs_pv_seg,
        cls_one2many, coords_one2many, pts_coords_one2many)``. The six
        prediction entries are Python lists with one tensor per decoder
        layer; the first ``num_vec_one2one`` vectors go to the one2one
        lists and the remainder to the one2many lists. The segmentation
        outputs are ``None`` unless the matching auxiliary head is enabled.
    """
    bev_embed, depth, hs, init_reference, inter_references = outputs
    # NOTE(review): presumably swaps the query and batch axes so each layer
    # is (bs, num_query, dim) -- confirm against the transformer output.
    hs = hs.permute(0, 2, 1, 3)

    # Two coordinates per point, or three when ground truth carries z.
    coord_dim = 3 if self.z_cfg['gt_z_flag'] else 2
    split = self.num_vec_one2one

    outputs_classes_one2one = []
    outputs_coords_one2one = []
    outputs_pts_coords_one2one = []
    outputs_classes_one2many = []
    outputs_coords_one2many = []
    outputs_pts_coords_one2many = []

    for lvl in range(hs.shape[0]):
        ref_source = init_reference if lvl == 0 else inter_references[lvl - 1]
        reference = inverse_sigmoid(ref_source[..., 0:coord_dim])

        layer_state = hs[lvl]
        # Classification works on the mean over the points of each vector.
        pooled = layer_state.view(
            bs, num_vec, self.num_pts_per_vec, -1).mean(2)
        outputs_class = self.cls_branches[lvl](pooled)

        # Regression predicts an offset in logit space around the reference.
        tmp = self.reg_branches[lvl](layer_state)
        tmp = tmp[..., 0:coord_dim]
        tmp += reference
        tmp = tmp.sigmoid()

        outputs_coord, outputs_pts_coord = self.transform_box(
            tmp, num_vec=num_vec)

        outputs_classes_one2one.append(outputs_class[:, 0:split])
        outputs_coords_one2one.append(outputs_coord[:, 0:split])
        outputs_pts_coords_one2one.append(outputs_pts_coord[:, 0:split])

        outputs_classes_one2many.append(outputs_class[:, split:])
        outputs_coords_one2many.append(outputs_coord[:, split:])
        outputs_pts_coords_one2many.append(outputs_pts_coord[:, split:])

    outputs_seg = None
    outputs_pv_seg = None
    if self.aux_seg['use_aux_seg']:
        # Reshape the flat BEV embedding into a (bs, C, bev_h, bev_w) map.
        seg_bev_embed = bev_embed.permute(1, 0, 2)
        seg_bev_embed = seg_bev_embed.view(bs, self.bev_h, self.bev_w, -1)
        seg_bev_embed = seg_bev_embed.permute(0, 3, 1, 2).contiguous()
        if self.aux_seg['bev_seg']:
            outputs_seg = self.seg_head(seg_bev_embed)

        last_feat = mlvl_feats[-1]
        bs, num_cam, embed_dims, feat_h, feat_w = last_feat.shape
        if self.aux_seg['pv_seg']:
            # Cameras are folded into the batch axis for the conv head.
            outputs_pv_seg = self.pv_seg_head(last_feat.flatten(0, 1))
            outputs_pv_seg = outputs_pv_seg.view(
                bs, num_cam, -1, feat_h, feat_w)

    return (bev_embed, outputs_classes_one2one, outputs_coords_one2one,
            outputs_pts_coords_one2one, depth, outputs_seg, outputs_pv_seg,
            outputs_classes_one2many, outputs_coords_one2many,
            outputs_pts_coords_one2many)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def prepare_transformer_inputs(self, mlvl_feats):
    """Assemble query embeddings, BEV inputs and the self-attention mask.

    Args:
        mlvl_feats: Sequence of feature tensors. Only the first entry is
            inspected; it must be 5-dimensional and supplies the batch
            size, dtype and device.

    Returns:
        tuple: ``(num_vec, object_query_embeds, bev_queries, bev_pos,
        self_attn_mask, bs)``. ``num_vec`` is ``self.num_vec`` in training
        mode and ``self.num_vec_one2one`` otherwise. ``bev_queries`` and
        ``bev_pos`` are ``None`` when ``self.bev_embedding`` is ``None``.
        ``self_attn_mask`` is a ``(num_vec, num_vec)`` bool tensor that is
        ``True`` wherever a one2one vector is paired with a one2many vector
        (in either direction).

    Raises:
        ValueError: If ``self.query_embed_type`` is neither ``'all_pts'``
            nor ``'instance_pts'``.
    """
    num_vec = self.num_vec if self.training else self.num_vec_one2one

    first_level = mlvl_feats[0]
    bs, _, _, _, _ = first_level.shape
    dtype = first_level.dtype

    if self.query_embed_type == 'all_pts':
        object_query_embeds = self.query_embedding.weight.to(dtype)
    elif self.query_embed_type == 'instance_pts':
        # Every query is the sum of its instance and point embeddings.
        pts_embeds = self.pts_embedding.weight.unsqueeze(0)
        instance_embeds = self.instance_embedding.weight[0:num_vec].unsqueeze(1)
        object_query_embeds = (
            pts_embeds + instance_embeds).flatten(0, 1).to(dtype)
    else:
        # Fail with a clear message instead of an unbound-local error at
        # the return statement.
        raise ValueError(
            f'Unsupported query_embed_type: {self.query_embed_type!r}')

    if self.bev_embedding is not None:
        bev_queries = self.bev_embedding.weight.to(dtype)
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device, dtype=dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)
    else:
        bev_queries = None
        bev_pos = None

    # Attention mask to prevent information leakage between the one2one
    # and one2many query groups. Allocated directly on the target device as
    # bool to avoid a host allocation followed by a device copy.
    self_attn_mask = torch.zeros(
        (num_vec, num_vec), dtype=torch.bool, device=first_level.device)
    split = self.num_vec_one2one
    self_attn_mask[split:, 0:split] = True
    self_attn_mask[0:split, split:] = True

    return (num_vec, object_query_embeds, bev_queries, bev_pos,
            self_attn_mask, bs)
# @auto_fp16(apply_to=('mlvl_feats'))
@force_fp32(apply_to=('mlvl_feats', 'prev_bev'))
def forward(self, mlvl_feats, lidar_feat, img_metas, prev_bev=None, only_bev=False):
    """Run the transformer and collect the head predictions.

    Args:
        mlvl_feats (tuple[Tensor]): Features from the upstream network;
            each is unpacked as a 5-D tensor, documented upstream as
            (B, N, C, H, W).
        lidar_feat: Passed through to the transformer unchanged.
        img_metas: Passed through to the transformer unchanged.
        prev_bev: Previous BEV features, passed through to the transformer.
        only_bev (bool): When true, only the encoder is run and its
            ``'bev'`` output is returned.

    Returns:
        When ``only_bev`` is true, the ``'bev'`` entry of
        ``self.transformer.get_bev_features(...)``. Otherwise a dict with
        keys ``'bev_embed'``, ``'all_cls_scores'``, ``'all_bbox_preds'``,
        ``'all_pts_preds'`` (per-layer one2one predictions stacked along a
        new leading axis), ``'enc_cls_scores'``, ``'enc_bbox_preds'``,
        ``'enc_pts_preds'`` (always ``None``), ``'depth'``, ``'seg'``,
        ``'pv_seg'`` and ``'one2many_outs'`` (a dict with the stacked
        one2many predictions and ``None`` for every other entry).
    """
    (num_vec, object_query_embeds, bev_queries, bev_pos,
     self_attn_mask, bs) = self.prepare_transformer_inputs(mlvl_feats)

    # Real-world size covered by a single BEV cell, as (h, w).
    grid_length = (self.real_h / self.bev_h, self.real_w / self.bev_w)

    if only_bev:
        # only use encoder to obtain BEV features, TODO: refine the workaround
        bev_result = self.transformer.get_bev_features(
            mlvl_feats,
            lidar_feat,
            bev_queries,
            self.bev_h,
            self.bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            img_metas=img_metas,
            prev_bev=prev_bev,
        )
        return bev_result['bev']

    outputs = self.transformer(
        mlvl_feats,
        lidar_feat,
        bev_queries,
        object_query_embeds,
        self.bev_h,
        self.bev_w,
        grid_length=grid_length,
        bev_pos=bev_pos,
        reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
        cls_branches=self.cls_branches if self.as_two_stage else None,
        img_metas=img_metas,
        prev_bev=prev_bev,
        self_attn_mask=self_attn_mask,
        num_vec=num_vec,
        num_pts_per_vec=self.num_pts_per_vec,
    )

    (bev_embed, cls_one2one, coords_one2one, pts_one2one, depth,
     outputs_seg, outputs_pv_seg, cls_one2many, coords_one2many,
     pts_one2many) = self.compute_decoder_predictions(
         outputs, bs, num_vec, mlvl_feats)

    # Stack the per-layer lists along a new leading (decoder layer) axis.
    cls_one2one = torch.stack(cls_one2one)
    coords_one2one = torch.stack(coords_one2one)
    pts_one2one = torch.stack(pts_one2one)
    cls_one2many = torch.stack(cls_one2many)
    coords_one2many = torch.stack(coords_one2many)
    pts_one2many = torch.stack(pts_one2many)

    one2many_outs = dict(
        all_cls_scores=cls_one2many,
        all_bbox_preds=coords_one2many,
        all_pts_preds=pts_one2many,
        enc_cls_scores=None,
        enc_bbox_preds=None,
        enc_pts_preds=None,
        seg=None,
        pv_seg=None,
    )
    return {
        'bev_embed': bev_embed,
        'all_cls_scores': cls_one2one,
        'all_bbox_preds': coords_one2one,
        'all_pts_preds': pts_one2one,
        'enc_cls_scores': None,
        'enc_bbox_preds': None,
        'enc_pts_preds': None,
        'depth': depth,
        'seg': outputs_seg,
        'pv_seg': outputs_pv_seg,
        "one2many_outs": one2many_outs,
    }
def transform_box(self, pts, num_vec=50, y_first=False):
    """Convert predicted point sets into bounding boxes.

    Args:
        pts (Tensor): Point predictions; reshaped to
            ``(pts.shape[0], num_vec, self.num_pts_per_vec, C)`` where
            ``C`` is 3 when ``z_cfg['gt_z_flag']`` is set and 2 otherwise.
        num_vec (int): Number of vectors used for the reshape.
        y_first (bool): If true, channel 0 is read as y and channel 1 as
            x; otherwise channel 0 is x and channel 1 is y.

    Returns:
        tuple[Tensor, Tensor]: The ``(cx, cy, w, h)`` boxes obtained from
        the per-vector min/max of x and y, and the reshaped points.

    Raises:
        NotImplementedError: If ``self.transform_method`` is not
            ``'minmax'``.
    """
    coord_dim = 3 if self.z_cfg['gt_z_flag'] else 2
    pts_reshape = pts.view(
        pts.shape[0], num_vec, self.num_pts_per_vec, coord_dim)

    y_channel, x_channel = (0, 1) if y_first else (1, 0)
    ys = pts_reshape[:, :, :, y_channel]
    xs = pts_reshape[:, :, :, x_channel]

    if self.transform_method != 'minmax':
        raise NotImplementedError

    # Reduce over the points axis, keeping it as size 1 for concatenation.
    x_lo = xs.min(dim=2, keepdim=True).values
    x_hi = xs.max(dim=2, keepdim=True).values
    y_lo = ys.min(dim=2, keepdim=True).values
    y_hi = ys.max(dim=2, keepdim=True).values

    corners = torch.cat([x_lo, y_lo, x_hi, y_hi], dim=2)
    bbox = bbox_xyxy_to_cxcywh(corners)
    return bbox, pts_reshape
def get_label_result(self, sampling_result, gt_bboxes, gt_labels, bbox_pred, order_index, pts_pred, gt_shifts_pts):
    """Build label, box and point targets from one sampling result.

    Args:
        sampling_result: Object exposing ``pos_inds``, ``neg_inds``,
            ``pos_assigned_gt_inds`` and ``pos_gt_bboxes``.
        gt_bboxes (Tensor): Ground-truth boxes; its last dimension sets the
            width of the box targets.
        gt_labels (Tensor): Ground-truth class indices.
        bbox_pred (Tensor): Box predictions; row count = number of queries.
        order_index (Tensor | None): Indexed as ``[pos_inds,
            pos_assigned_gt_inds]`` to select which shifted ground-truth
            variant each positive uses. When ``None`` the assigned labels
            are used as that index instead.
        pts_pred (Tensor): Point predictions; its first three sizes define
            the point-target shape.
        gt_shifts_pts (Tensor): Indexed as
            ``[gt_index, shift_index, :, :]``.

    Returns:
        tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
        pts_targets, pts_weights, pos_inds, neg_inds)``. Labels default to
        ``self.num_classes``; weights are 1 for positives (label weights
        are 1 everywhere).
    """
    query_count = bbox_pred.size(0)
    box_channels = gt_bboxes.shape[-1]

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    matched_gt = sampling_result.pos_assigned_gt_inds

    # Classification targets: background everywhere except the positives.
    labels = gt_bboxes.new_full(
        (query_count,), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[matched_gt]
    label_weights = gt_bboxes.new_ones(query_count)

    # Box targets and weights.
    bbox_targets = torch.zeros_like(bbox_pred)[..., :box_channels]
    bbox_weights = torch.zeros_like(bbox_pred)
    bbox_weights[pos_inds] = 1.0

    if order_index is not None:
        assigned_shift = order_index[pos_inds, matched_gt]
    else:
        assigned_shift = gt_labels[matched_gt]

    # Point targets and weights.
    pts_targets = pts_pred.new_zeros(
        (pts_pred.size(0), pts_pred.size(1), pts_pred.size(2)))
    pts_weights = torch.zeros_like(pts_targets)
    pts_weights[pos_inds] = 1.0

    # DETR
    bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
    pts_targets[pos_inds] = gt_shifts_pts[matched_gt, assigned_shift, :, :]

    return (labels, label_weights, bbox_targets, bbox_weights,
            pts_targets, pts_weights, pos_inds, neg_inds)
def
_get_target_single
(
self
,
cls_score
,
bbox_pred
,
pts_pred
,
gt_labels
,
gt_bboxes
,
gt_shifts_pts
,
gt_bboxes_ignore
=
None
):
""""Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (cx, cy, w, h) and
shape [num_query, 4].
gt_bboxes (Tensor): Ground truth bboxes for one image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (Tensor): Ground truth class indices for one image
with shape (num_gts, ).
gt_bboxes_ignore (Tensor, optional): Bounding boxes
which can be ignored. Default None.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
- label_weights (Tensor]): Label weights of each image.
- bbox_targets (Tensor): BBox targets of each image.
- bbox_weights (Tensor): BBox weights of each image.
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
# num_bboxes = bbox_pred.size(0)
# # assigner and sampler
# gt_c = gt_bboxes.shape[-1]
assign_result
,
order_index
=
self
.
assigner
.
assign
(
bbox_pred
,
cls_score
,
pts_pred
,
gt_bboxes
,
gt_labels
,
gt_shifts_pts
,
gt_bboxes_ignore
)
sampling_result
=
self
.
sampler
.
sample
(
assign_result
,
bbox_pred
,
gt_bboxes
[
0
])
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pts_targets
,
pts_weights
,
pos_inds
,
neg_inds
=
self
.
get_label_result
(
sampling_result
,
gt_bboxes
[
0
],
gt_labels
[
0
],
bbox_pred
,
order_index
,
pts_pred
,
gt_shifts_pts
[
0
])
# sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
# labels, label_weights, bbox_targets, bbox_weights, pts_targets, pts_weights, pos_inds, neg_inds = self.get_label_result(sampling_result, gt_bboxes, gt_labels, bbox_pred, order_index, pts_pred, gt_shifts_pts)
return
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pts_targets
,
pts_weights
,
pos_inds
,
neg_inds
)
def get_targets(self, cls_scores_list, bbox_preds_list, pts_preds_list, gt_bboxes_list, gt_labels_list, gt_shifts_pts_list, gt_bboxes_ignore_list=None):
    """Compute regression and classification targets for a batch.

    ``_get_target_single`` is applied to every image through
    ``multi_apply``.

    Args:
        cls_scores_list (list[Tensor]): Per-image classification logits.
        bbox_preds_list (list[Tensor]): Per-image box predictions.
        pts_preds_list (list[Tensor]): Per-image point predictions.
        gt_bboxes_list (list): Per-image ground-truth boxes.
        gt_labels_list (list): Per-image ground-truth labels.
        gt_shifts_pts_list (list): Per-image shifted ground-truth points.
        gt_bboxes_ignore_list: Must be ``None``.

    Returns:
        tuple: ``(labels_list, label_weights_list, bbox_targets_list,
        bbox_weights_list, pts_targets_list, pts_weights_list,
        num_total_pos, num_total_neg)`` where the last two are the summed
        element counts of the positive / negative index tensors.
    """
    assert gt_bboxes_ignore_list is None, \
        'Only supports for gt_bboxes_ignore setting to None.'

    image_count = len(cls_scores_list)
    ignore_per_image = [gt_bboxes_ignore_list] * image_count

    per_image_targets = multi_apply(
        self._get_target_single,
        cls_scores_list, bbox_preds_list, pts_preds_list,
        gt_labels_list, gt_bboxes_list, gt_shifts_pts_list,
        ignore_per_image)
    (labels_list, label_weights_list, bbox_targets_list,
     bbox_weights_list, pts_targets_list, pts_weights_list,
     pos_inds_list, neg_inds_list) = per_image_targets

    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)

    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, pts_targets_list, pts_weights_list,
            num_total_pos, num_total_neg)
def loss_single(self, cls_scores, bbox_preds, pts_preds, gt_bboxes_list, gt_labels_list, gt_shifts_pts_list, gt_bboxes_ignore_list=None):
    """Compute the losses for the outputs of a single decoder layer.

    Args:
        cls_scores (Tensor): Classification logits for all images; reshaped
            to ``(-1, self.cls_out_channels)`` for the loss.
        bbox_preds (Tensor): Box predictions for all images; flattened to
            ``(-1, C)``, of which the first four channels are used.
        pts_preds (Tensor): Point predictions for all images; flattened to
            ``(-1, num_pts, num_coords)``.
        gt_bboxes_list (list): Per-image ground-truth boxes.
        gt_labels_list (list): Per-image ground-truth labels.
        gt_shifts_pts_list (list): Per-image shifted ground-truth points.
        gt_bboxes_ignore_list: Forwarded to ``get_targets``. Default None.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_bbox, loss_iou, loss_pts,
        loss_dir)``. On torch >= 1.8 each value is passed through
        ``torch.nan_to_num``.
    """
    batch = cls_scores.size(0)
    cls_scores_list = [cls_scores[idx] for idx in range(batch)]
    bbox_preds_list = [bbox_preds[idx] for idx in range(batch)]
    pts_preds_list = [pts_preds[idx] for idx in range(batch)]

    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     pts_targets_list, pts_weights_list,
     num_total_pos, num_total_neg) = self.get_targets(
         cls_scores_list, bbox_preds_list, pts_preds_list,
         gt_bboxes_list, gt_labels_list, gt_shifts_pts_list,
         gt_bboxes_ignore_list)

    labels = torch.cat(labels_list, 0)
    label_weights = torch.cat(label_weights_list, 0)
    bbox_targets = torch.cat(bbox_targets_list, 0)
    bbox_weights = torch.cat(bbox_weights_list, 0)
    pts_targets = torch.cat(pts_targets_list, 0)
    pts_weights = torch.cat(pts_weights_list, 0)

    # ---- classification loss ----
    cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
    # construct weighted avg_factor to match with the official DETR repo
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        cls_avg_factor = reduce_mean(
            cls_scores.new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)
    loss_cls = self.loss_cls(
        cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

    # Compute the average number of gt boxes across all gpus, for
    # normalization purposes
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

    # ---- box regression loss ----
    bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
    normalized_bbox_targets = normalize_2d_bbox(bbox_targets, self.pc_range)
    # Rows whose normalized target holds a non-finite value are dropped
    # from every regression loss below.
    isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
    bbox_weights = bbox_weights * self.code_weights

    loss_bbox = self.loss_bbox(
        bbox_preds[isnotnan, :4],
        normalized_bbox_targets[isnotnan, :4],
        bbox_weights[isnotnan, :4],
        avg_factor=num_total_pos)

    # ---- point regression loss ----
    if not self.z_cfg['gt_z_flag']:
        normalized_pts_targets = normalize_2d_pts(pts_targets, self.pc_range)
    else:
        normalized_pts_targets = normalize_3d_pts(pts_targets, self.pc_range)

    pts_preds = pts_preds.reshape(-1, pts_preds.size(-2), pts_preds.size(-1))
    if self.num_pts_per_vec != self.num_pts_per_gt_vec:
        # Resample predictions along the points axis to the number of
        # ground-truth points.
        pts_preds = pts_preds.permute(0, 2, 1)
        pts_preds = F.interpolate(pts_preds,
                                  size=(self.num_pts_per_gt_vec),
                                  mode='linear',
                                  align_corners=True)
        pts_preds = pts_preds.permute(0, 2, 1).contiguous()

    loss_pts = self.loss_pts(
        pts_preds[isnotnan, :, :],
        normalized_pts_targets[isnotnan, :, :],
        pts_weights[isnotnan, :, :],
        avg_factor=num_total_pos)

    # ---- direction loss on point differences `dir_interval` apart ----
    step = self.dir_interval
    dir_weights = pts_weights[:, :-step, 0]
    if not self.z_cfg['gt_z_flag']:
        denormed_pts_preds = denormalize_2d_pts(pts_preds, self.pc_range)
    else:
        denormed_pts_preds = denormalize_3d_pts(pts_preds, self.pc_range)
    denormed_pts_preds_dir = \
        denormed_pts_preds[:, step:, :] - denormed_pts_preds[:, :-step, :]
    pts_targets_dir = pts_targets[:, step:, :] - pts_targets[:, :-step, :]

    loss_dir = self.loss_dir(
        denormed_pts_preds_dir[isnotnan, :, :],
        pts_targets_dir[isnotnan, :, :],
        dir_weights[isnotnan, :],
        avg_factor=num_total_pos)

    # ---- IoU loss on denormalized boxes ----
    bboxes = denormalize_2d_bbox(bbox_preds, self.pc_range)
    loss_iou = self.loss_iou(
        bboxes[isnotnan, :4],
        bbox_targets[isnotnan, :4],
        bbox_weights[isnotnan, :4],
        avg_factor=num_total_pos)

    if digit_version(TORCH_VERSION) >= digit_version('1.8'):
        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)
        loss_iou = torch.nan_to_num(loss_iou)
        loss_pts = torch.nan_to_num(loss_pts)
        loss_dir = torch.nan_to_num(loss_dir)

    return loss_cls, loss_bbox, loss_iou, loss_pts, loss_dir
import
torch
def pad_to_static_list(self, tensors, pad_value=0, device=None, max_len=200):
    """Pad each tensor along dim 0 to a fixed length.

    Args:
        tensors (Iterable[Tensor]): Tensors to pad; trailing dimensions are
            kept as they are.
        pad_value: Fill value for the padded rows. Default 0.
        device: Device for the padded tensor and mask; ``None`` lets torch
            choose its default device.
        max_len (int): Fixed length of dim 0 after padding. Default 200,
            the value that was previously hard-coded.

    Returns:
        list[tuple[Tensor, Tensor, int]]: For every input, the padded
        tensor of shape ``(max_len, *t.shape[1:])``, a bool mask of shape
        ``(max_len,)`` that is ``True`` on the original rows, and the
        original length.

    Raises:
        ValueError: If a tensor has more than ``max_len`` rows. Previously
            this surfaced as an opaque shape-mismatch error from the
            slice assignment.
    """
    results = []
    for t in tensors:
        length = t.size(0)
        if length > max_len:
            raise ValueError(
                f'tensor has {length} rows, which exceeds max_len={max_len}')

        pad_shape = (max_len,) + tuple(t.shape[1:])
        out = torch.full(pad_shape, pad_value, device=device, dtype=t.dtype)
        mask = torch.zeros(max_len, dtype=torch.bool, device=device)

        out[:length, ...] = t
        mask[:length] = True
        results.append((out, mask, length))
    return results
@force_fp32(apply_to=('preds_dicts',))
def loss(self,
         gt_bboxes_list,
         gt_labels_list,
         gt_seg_mask,
         gt_pv_seg_mask,
         preds_dicts,
         gt_bboxes_ignore=None,
         img_metas=None):
    """Compute the training losses of the head.

    Args:
        gt_bboxes_list (list): Per-image ground-truth vector containers.
            Each item must expose ``.bbox`` and the
            ``shift_fixed_num_sampled_points*`` attribute selected by
            ``self.gt_shift_pts_pattern``.
        gt_labels_list (list[Tensor]): Per-image ground-truth labels.
        gt_seg_mask: BEV segmentation targets, indexable per image. Only
            read when BEV auxiliary segmentation is enabled.
        gt_pv_seg_mask: Perspective-view segmentation targets, indexable
            per image. Only read when PV auxiliary segmentation is enabled.
        preds_dicts (dict): Head outputs. Reads ``all_cls_scores``,
            ``all_bbox_preds``, ``all_pts_preds`` (one entry per decoder
            layer), ``enc_cls_scores``, ``enc_bbox_preds``,
            ``enc_pts_preds`` (may be None) and, when auxiliary
            segmentation is enabled, ``seg`` / ``pv_seg``.
        gt_bboxes_ignore: Must be None.
        img_metas: Unused here.

    Returns:
        dict[str, Tensor]: Loss components. The last decoder layer is
        stored under ``loss_*``, earlier layers under ``d{i}.loss_*``.
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'

    # Work on a copy so the caller's ground-truth containers stay untouched.
    gt_vecs_list = copy.deepcopy(gt_bboxes_list)

    all_cls_scores = preds_dicts['all_cls_scores']
    all_bbox_preds = preds_dicts['all_bbox_preds']
    all_pts_preds = preds_dicts['all_pts_preds']
    enc_cls_scores = preds_dicts['enc_cls_scores']
    enc_bbox_preds = preds_dicts['enc_bbox_preds']
    enc_pts_preds = preds_dicts['enc_pts_preds']

    num_dec_layers = len(all_cls_scores)
    device = gt_labels_list[0].device

    gt_bboxes_list = [
        gt_bboxes.bbox.to(device) for gt_bboxes in gt_vecs_list]

    # Map the configured pattern to the attribute holding the shifted points.
    shift_attr_by_pattern = {
        'v0': 'shift_fixed_num_sampled_points',
        'v1': 'shift_fixed_num_sampled_points_v1',
        'v2': 'shift_fixed_num_sampled_points_v2',
        'v3': 'shift_fixed_num_sampled_points_v3',
        'v4': 'shift_fixed_num_sampled_points_v4',
    }
    if self.gt_shift_pts_pattern not in shift_attr_by_pattern:
        raise NotImplementedError
    shift_attr = shift_attr_by_pattern[self.gt_shift_pts_pattern]
    gt_shifts_pts_list = [
        getattr(gt_bboxes, shift_attr) for gt_bboxes in gt_vecs_list]

    # Pad every per-image target to a static length; each entry becomes a
    # (padded, mask, length) triple.
    all_gt_bboxes = self.pad_to_static_list(gt_bboxes_list, device=device)
    all_shifts_pts = self.pad_to_static_list(gt_shifts_pts_list, device=device)
    all_gt_labels = self.pad_to_static_list(gt_labels_list, device=device)

    # The same targets are reused for every decoder layer.
    all_gt_bboxes_list = [all_gt_bboxes for _ in range(num_dec_layers)]
    all_gt_shifts_pts_list = [all_shifts_pts for _ in range(num_dec_layers)]
    all_gt_labels_list = [all_gt_labels for _ in range(num_dec_layers)]
    all_gt_bboxes_ignore_list = [
        gt_bboxes_ignore for _ in range(num_dec_layers)]

    losses_cls, losses_bbox, losses_iou, losses_pts, losses_dir = multi_apply(
        self.loss_single, all_cls_scores, all_bbox_preds, all_pts_preds,
        all_gt_bboxes_list, all_gt_labels_list, all_gt_shifts_pts_list,
        all_gt_bboxes_ignore_list)

    loss_dict = dict()

    if self.aux_seg['use_aux_seg']:
        if self.aux_seg['bev_seg']:
            if preds_dicts['seg'] is not None:
                seg_output = preds_dicts['seg']
                num_imgs = seg_output.size(0)
                seg_gt = torch.stack(
                    [gt_seg_mask[i] for i in range(num_imgs)], dim=0)
                loss_dict['loss_seg'] = self.loss_seg(
                    seg_output, seg_gt.float())
        if self.aux_seg['pv_seg']:
            if preds_dicts['pv_seg'] is not None:
                pv_seg_output = preds_dicts['pv_seg']
                num_imgs = pv_seg_output.size(0)
                pv_seg_gt = torch.stack(
                    [gt_pv_seg_mask[i] for i in range(num_imgs)], dim=0)
                loss_dict['loss_pv_seg'] = self.loss_pv_seg(
                    pv_seg_output, pv_seg_gt.float())

    # Loss of the proposals generated from the encoder feature map.
    if enc_cls_scores is not None:
        # Class-agnostic targets: one zero label per ground-truth instance
        # of every image. The previous code sized this list by the number of
        # decoder layers and referenced an undefined `gt_pts_list`.
        binary_labels_list = [
            torch.zeros_like(labels) for labels in gt_labels_list]
        # Same padded target format that the decoder layers receive above.
        all_binary_labels = self.pad_to_static_list(
            binary_labels_list, device=device)
        enc_loss_cls, enc_losses_bbox, enc_losses_iou, \
            enc_losses_pts, enc_losses_dir = self.loss_single(
                enc_cls_scores, enc_bbox_preds, enc_pts_preds,
                all_gt_bboxes, all_binary_labels, all_shifts_pts,
                gt_bboxes_ignore)
        loss_dict['enc_loss_cls'] = enc_loss_cls
        loss_dict['enc_loss_bbox'] = enc_losses_bbox
        loss_dict['enc_losses_iou'] = enc_losses_iou
        loss_dict['enc_losses_pts'] = enc_losses_pts
        loss_dict['enc_losses_dir'] = enc_losses_dir

    # Loss from the last decoder layer.
    loss_dict['loss_cls'] = losses_cls[-1]
    loss_dict['loss_bbox'] = losses_bbox[-1]
    loss_dict['loss_iou'] = losses_iou[-1]
    loss_dict['loss_pts'] = losses_pts[-1]
    loss_dict['loss_dir'] = losses_dir[-1]

    # Losses from the other decoder layers.
    per_layer = zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1],
                    losses_pts[:-1], losses_dir[:-1])
    for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i, loss_pts_i,
                        loss_dir_i) in enumerate(per_layer):
        loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
        loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
        loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
        loss_dict[f'd{num_dec_layer}.loss_pts'] = loss_pts_i
        loss_dict[f'd{num_dec_layer}.loss_dir'] = loss_dir_i
    return loss_dict
@force_fp32(apply_to=('preds_dicts'))
def get_bboxes(self, preds_dicts, img_metas, rescale=False):
    """Decode raw head predictions into per-sample results.

    Args:
        preds_dicts: Raw head outputs, passed to ``self.bbox_coder.decode``.
        img_metas: Unused here.
        rescale (bool): Unused here.

    Returns:
        list[list]: One ``[bboxes, scores, labels, pts]`` list per decoded
        sample.
    """
    decoded_samples = self.bbox_coder.decode(preds_dicts)
    ret_list = []
    for sample_idx in range(len(decoded_samples)):
        sample = decoded_samples[sample_idx]
        ret_list.append(
            [sample['bboxes'], sample['scores'], sample['labels'], sample['pts']])
    return ret_list
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/__init__.py
0 → 100644
View file @
19472568
from
.maptr
import
MapTR
from
.maptrv2
import
MapTRv2
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/maptr.py
0 → 100644
View file @
19472568
import
copy
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models
import
DETECTORS
from
mmdet3d.core
import
bbox3d2result
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
projects.mmdet3d_plugin.models.utils.grid_mask
import
GridMask
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmdet3d.ops
import
Voxelization
,
DynamicScatter
from
mmdet3d.models
import
builder
@
DETECTORS
.
register_module
()
class
MapTR
(
MVXTwoStageDetector
):
"""MapTR.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def __init__(self,
             use_grid_mask=False,
             pts_voxel_layer=None,
             pts_voxel_encoder=None,
             pts_middle_encoder=None,
             pts_fusion_layer=None,
             img_backbone=None,
             pts_backbone=None,
             img_neck=None,
             pts_neck=None,
             pts_bbox_head=None,
             img_roi_head=None,
             img_rpn_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None,
             video_test_mode=False,
             modality='vision',
             lidar_encoder=None,
             ):
    """Build the detector.

    The sub-module configs are forwarded positionally to the parent
    constructor. On top of that this sets up the grid-mask augmentation,
    the temporal state used at test time and, for ``modality == 'fusion'``
    with a ``lidar_encoder`` config, the lidar voxelizer and backbone.
    """
    super(MapTR, self).__init__(
        pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
        pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
        pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg,
        pretrained)

    # The grid mask is always constructed; `use_grid_mask` gates its use.
    self.grid_mask = GridMask(
        True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
    self.use_grid_mask = use_grid_mask
    self.fp16_enabled = False

    # Temporal state carried between consecutive test frames.
    self.video_test_mode = video_test_mode
    self.prev_frame_info = dict(
        prev_bev=None, scene_token=None, prev_pos=0, prev_angle=0)

    self.modality = modality
    if self.modality != 'fusion' or lidar_encoder is None:
        return

    # A positive `max_num_points` selects Voxelization, otherwise
    # DynamicScatter is used.
    voxelize_cfg = lidar_encoder["voxelize"]
    if voxelize_cfg.get("max_num_points", -1) > 0:
        voxelizer_cls = Voxelization
    else:
        voxelizer_cls = DynamicScatter
    self.lidar_modal_extractor = nn.ModuleDict({
        "voxelize": voxelizer_cls(**lidar_encoder["voxelize"]),
        "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
    })
    self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)
def extract_img_feat(self, img, img_metas, len_queue=None):
    """Extract image features with the backbone (and neck, if present).

    Args:
        img (Tensor | None): Input images. A 5-D tensor is read as
            ``(B, N, C, H, W)`` and flattened to ``(B * N, C, H, W)``;
            any other rank is passed to the backbone unchanged.
        img_metas: Unused here.
        len_queue (int | None): When given, the leading dimension is
            additionally split into ``(B // len_queue, len_queue)``.

    Returns:
        list[Tensor] | None: One tensor per feature level, reshaped to
        ``(B, BN // B, C, H, W)`` (or
        ``(B // len_queue, len_queue, BN // B, C, H, W)``), or None when
        ``img`` is None.
    """
    # Check for None before touching the tensor; previously `img.size(0)`
    # ran first, so the None branch could never be reached.
    if img is None:
        return None

    B = img.size(0)
    if img.dim() == 5:
        # Non-mutating reshape: the old in-place `squeeze_()` altered the
        # caller's tensor and collapsed every singleton dimension.
        B, N, C, H, W = img.size()
        img = img.reshape(B * N, C, H, W)
    if self.use_grid_mask:
        img = self.grid_mask(img)

    img_feats = self.img_backbone(img)
    if isinstance(img_feats, dict):
        img_feats = list(img_feats.values())
    if self.with_img_neck:
        img_feats = self.img_neck(img_feats)

    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        if len_queue is not None:
            img_feats_reshaped.append(
                img_feat.view(B // len_queue, len_queue, BN // B, C, H, W))
        else:
            img_feats_reshaped.append(img_feat.view(B, BN // B, C, H, W))
    return img_feats_reshaped
@auto_fp16(apply_to=('img'), out_fp32=True)
def extract_feat(self, img, img_metas=None, len_queue=None):
    """Return the image features; thin wrapper around ``extract_img_feat``."""
    return self.extract_img_feat(img, img_metas, len_queue=len_queue)
def forward_pts_train(self,
                      pts_feats,
                      lidar_feat,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      img_metas,
                      gt_bboxes_ignore=None,
                      prev_bev=None):
    """Run the head and compute its losses.

    Args:
        pts_feats: Features handed to the head as its first argument.
        lidar_feat: Lidar features handed to the head (may be None).
        gt_bboxes_3d: Ground-truth instances, forwarded to ``head.loss``.
        gt_labels_3d: Ground-truth labels, forwarded to ``head.loss``.
        img_metas: Meta information, given to both the head and its loss.
        gt_bboxes_ignore: Accepted for interface compatibility; not used.
        prev_bev: Previous BEV features handed to the head.

    Returns:
        Whatever ``self.pts_bbox_head.loss`` returns.
    """
    head_outputs = self.pts_bbox_head(pts_feats, lidar_feat, img_metas, prev_bev)
    return self.pts_bbox_head.loss(
        gt_bboxes_3d, gt_labels_3d, head_outputs, img_metas=img_metas)
def forward_dummy(self, img):
    """Run ``forward_test`` on ``img`` with a placeholder ``[[None]]`` meta list."""
    return self.forward_test(img=img, img_metas=[[None]])
def forward(self, return_loss=True, **kwargs):
    """Dispatch to ``forward_train`` or ``forward_test``.

    Args:
        return_loss (bool): If truthy, ``forward_train(**kwargs)`` is
            called; otherwise ``forward_test(**kwargs)``.
        **kwargs: Forwarded unchanged to the selected method.

    Returns:
        The return value of the selected method.
    """
    if not return_loss:
        return self.forward_test(**kwargs)
    return self.forward_train(**kwargs)
def obtain_history_bev(self, imgs_queue, img_metas_list):
    """Compute the BEV features of the history frames without gradients.

    Args:
        imgs_queue (Tensor): History images, unpacked as
            ``(bs, len_queue, num_cams, C, H, W)``.
        img_metas_list (list): Per-sample sequences of meta information,
            indexed by frame.

    Returns:
        The BEV features returned by the head for the last history frame
        (None when ``len_queue`` is 0).
    """
    # Remember the current mode so it can be restored afterwards; the old
    # code always switched to train mode.
    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            prev_bev = None
            bs, len_queue, num_cams, C, H, W = imgs_queue.shape
            imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
            img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
            for i in range(len_queue):
                img_metas = [each[i] for each in img_metas_list]
                img_feats = [each_scale[:, i] for each_scale in img_feats_list]
                # Every other call site passes (feats, lidar_feat, img_metas,
                # prev_bev) to the head. No lidar features are extracted for
                # history frames, so the slot is filled with None instead of
                # letting img_metas slide into it.
                prev_bev = self.pts_bbox_head(
                    img_feats, None, img_metas, prev_bev, only_bev=True)
    finally:
        self.train(was_training)
    return prev_bev
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
    """Voxelize a batch of point clouds and concatenate the results.

    Args:
        points (Iterable): Point clouds; each is passed to
            ``self.lidar_modal_extractor["voxelize"]``.

    Returns:
        tuple: ``(feats, coords, sizes)``. ``coords`` has the sample index
        prepended as an extra leading column. ``sizes`` is an empty list
        when the voxelizer returns only ``(feats, coords)``.
    """
    feats, coords, sizes = [], [], []
    for sample_idx, cloud in enumerate(points):
        outputs = self.lidar_modal_extractor["voxelize"](cloud)
        if len(outputs) == 3:
            # hard voxelize
            voxel_feats, voxel_coords, num_points = outputs
        else:
            assert len(outputs) == 2
            voxel_feats, voxel_coords = outputs
            num_points = None
        feats.append(voxel_feats)
        # Prepend the sample index to every coordinate row.
        coords.append(
            F.pad(voxel_coords, (1, 0), mode="constant", value=sample_idx))
        if num_points is not None:
            sizes.append(num_points)

    feats = torch.cat(feats, dim=0)
    coords = torch.cat(coords, dim=0)
    if len(sizes) <= 0:
        return feats, coords, sizes

    sizes = torch.cat(sizes, dim=0)
    if self.voxelize_reduce:
        # Average over dim 1 using the per-voxel point counts.
        point_counts = sizes.type_as(feats).view(-1, 1)
        feats = feats.sum(dim=1, keepdim=False) / point_counts
        feats = feats.contiguous()
    return feats, coords, sizes
@auto_fp16(apply_to=('points'), out_fp32=True)
def extract_lidar_feat(self, points):
    """Voxelize ``points`` and run the lidar backbone on the voxels.

    The batch size is derived from the sample index stored in column 0 of
    the last coordinate row.
    """
    voxel_feats, voxel_coords, voxel_sizes = self.voxelize(points)
    num_samples = voxel_coords[-1, 0] + 1
    backbone = self.lidar_modal_extractor["backbone"]
    return backbone(voxel_feats, voxel_coords, num_samples, sizes=voxel_sizes)
# @auto_fp16(apply_to=('img', 'points'))
@force_fp32(apply_to=('img', 'points', 'prev_bev'))
def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  img_depth=None,
                  img_mask=None,
                  ):
    """Training forward pass.

    Args:
        points: Point clouds; only used when ``self.modality == 'fusion'``.
        img_metas (list): Per-sample sequences of meta information,
            indexed by frame.
        gt_bboxes_3d: Ground-truth instances, forwarded to the head loss.
        gt_labels_3d: Ground-truth labels, forwarded to the head loss.
        img (Tensor): Images whose dim 1 is the frame queue; the last
            frame is the current one, earlier frames are history.
        gt_bboxes_ignore: Forwarded to ``forward_pts_train``.
        gt_labels, gt_bboxes, proposals, img_depth, img_mask: Unused here.

    Returns:
        dict: Losses returned by ``forward_pts_train``.
    """
    if self.modality == 'fusion':
        lidar_feat = self.extract_lidar_feat(points)
    else:
        lidar_feat = None

    len_queue = img.size(1)
    history_imgs = img[:, :-1, ...]
    img = img[:, -1, ...]

    history_metas = copy.deepcopy(img_metas)
    if len_queue > 1:
        prev_bev = self.obtain_history_bev(history_imgs, history_metas)
    else:
        prev_bev = None

    # Keep only the meta information of the current (last) frame.
    img_metas = [each[len_queue - 1] for each in img_metas]
    img_feats = self.extract_feat(img=img, img_metas=img_metas)

    losses = dict()
    losses.update(
        self.forward_pts_train(img_feats, lidar_feat, gt_bboxes_3d,
                               gt_labels_3d, img_metas, gt_bboxes_ignore,
                               prev_bev))
    return losses
def forward_test(self, img_metas, img=None, points=None, **kwargs):
    """Test-time forward pass with optional temporal BEV reuse.

    Args:
        img_metas (list): Nested meta information; ``img_metas[0][0]`` must
            provide ``'scene_token'`` and ``'can_bus'``. The ``can_bus``
            entry is modified in place.
        img: Images; ``img[0]`` is passed to ``simple_test``.
        points: Point clouds; ``points[0]`` is passed to ``simple_test``.
        **kwargs: Forwarded to ``simple_test``.

    Returns:
        The detection results returned by ``simple_test``.

    Raises:
        TypeError: If ``img_metas`` is not a list.
    """
    if not isinstance(img_metas, list):
        raise TypeError('{} must be a list, but got {}'.format(
            'img_metas', type(img_metas)))
    if img is None:
        img = [img]
    if points is None:
        points = [points]

    frame_state = self.prev_frame_info
    current_meta = img_metas[0][0]

    if current_meta['scene_token'] != frame_state['scene_token']:
        # the first sample of each scene is truncated
        frame_state['prev_bev'] = None
    # update idx
    frame_state['scene_token'] = current_meta['scene_token']

    # do not use temporal information
    if not self.video_test_mode:
        frame_state['prev_bev'] = None

    # Keep the absolute pose before it is overwritten with the delta.
    abs_pos = copy.deepcopy(current_meta['can_bus'][:3])
    abs_angle = copy.deepcopy(current_meta['can_bus'][-1])
    if frame_state['prev_bev'] is None:
        current_meta['can_bus'][-1] = 0
        current_meta['can_bus'][:3] = 0
    else:
        current_meta['can_bus'][:3] -= frame_state['prev_pos']
        current_meta['can_bus'][-1] -= frame_state['prev_angle']

    new_prev_bev, bbox_results = self.simple_test(
        img_metas[0], img[0], points[0],
        prev_bev=frame_state['prev_bev'], **kwargs)

    # Save the BEV features and ego pose of this timestamp for the next call.
    frame_state['prev_pos'] = abs_pos
    frame_state['prev_angle'] = abs_angle
    frame_state['prev_bev'] = new_prev_bev
    return bbox_results
def pred2result(self, bboxes, scores, labels, pts, attrs=None):
    """Move the predictions of one sample to the CPU and pack them in a dict.

    Args:
        bboxes: Predicted boxes; must support ``.to('cpu')``.
        scores: Prediction scores; must support ``.cpu()``.
        labels: Predicted labels; must support ``.cpu()``.
        pts: Predicted points; must support ``.to('cpu')``.
        attrs: Optional attributes; must support ``.cpu()``.

    Returns:
        dict: Keys ``boxes_3d``, ``scores_3d``, ``labels_3d``, ``pts_3d``
        and, when ``attrs`` is given, ``attrs_3d``.
    """
    result_dict = {
        'boxes_3d': bboxes.to('cpu'),
        'scores_3d': scores.cpu(),
        'labels_3d': labels.cpu(),
        'pts_3d': pts.to('cpu'),
    }
    if attrs is None:
        return result_dict
    result_dict['attrs_3d'] = attrs.cpu()
    return result_dict
def simple_test_pts(self, x, lidar_feat, img_metas, prev_bev=None, rescale=False):
    """Run the head on one batch and convert its output to result dicts.

    Returns:
        tuple: ``(outs['bev_embed'], bbox_results)`` where ``bbox_results``
        holds one ``pred2result`` output per decoded sample.
    """
    outs = self.pts_bbox_head(x, lidar_feat, img_metas, prev_bev=prev_bev)
    decoded = self.pts_bbox_head.get_bboxes(outs, img_metas, rescale=rescale)
    bbox_results = []
    for bboxes, scores, labels, pts in decoded:
        bbox_results.append(self.pred2result(bboxes, scores, labels, pts))
    return outs['bev_embed'], bbox_results
def simple_test(self, img_metas, img=None, points=None, prev_bev=None, rescale=False, **kwargs):
    """Single-pass inference without test-time augmentation.

    Returns:
        tuple: ``(new_prev_bev, bbox_list)`` where ``bbox_list`` holds one
        dict per entry of ``img_metas``; each dict that has a matching
        prediction stores it under ``'pts_bbox'``.
    """
    if self.modality == 'fusion':
        lidar_feat = self.extract_lidar_feat(points)
    else:
        lidar_feat = None
    img_feats = self.extract_feat(img=img, img_metas=img_metas)

    bbox_list = [dict() for _ in range(len(img_metas))]
    new_prev_bev, bbox_pts = self.simple_test_pts(
        img_feats, lidar_feat, img_metas, prev_bev, rescale=rescale)
    for sample_result, pts_bbox in zip(bbox_list, bbox_pts):
        sample_result['pts_bbox'] = pts_bbox
    return new_prev_bev, bbox_list
@
DETECTORS
.
register_module
()
class
MapTR_fp16
(
MapTR
):
"""
The default version BEVFormer currently can not support FP16.
We provide this version to resolve this issue.
"""
# @auto_fp16(apply_to=('img', 'prev_bev', 'points'))
@force_fp32(apply_to=('img', 'points', 'prev_bev'))
def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  img_depth=None,
                  img_mask=None,
                  prev_bev=None,
                  ):
    """Training forward pass that takes the previous BEV as an input.

    Args:
        points: Point clouds; only used when ``self.modality == 'fusion'``.
        img_metas: Meta information, forwarded as given.
        gt_bboxes_3d: Ground-truth instances, forwarded to the head loss.
        gt_labels_3d: Ground-truth labels, forwarded to the head loss.
        img: Images passed to ``extract_feat``.
        gt_bboxes_ignore: Forwarded to ``forward_pts_train``.
        prev_bev: BEV features of the previous frame (may be None).
        gt_labels, gt_bboxes, proposals, img_depth, img_mask: Unused here.

    Returns:
        dict: Losses returned by ``forward_pts_train``.
    """
    # Same lidar handling as the other forward paths of this detector.
    lidar_feat = None
    if self.modality == 'fusion':
        lidar_feat = self.extract_lidar_feat(points)

    img_feats = self.extract_feat(img=img, img_metas=img_metas)

    losses = dict()
    # forward_pts_train expects (pts_feats, lidar_feat, gt_bboxes_3d, ...).
    # The lidar slot used to be omitted, which shifted every ground-truth
    # argument by one position.
    losses_pts = self.forward_pts_train(img_feats, lidar_feat, gt_bboxes_3d,
                                        gt_labels_3d, img_metas,
                                        gt_bboxes_ignore, prev_bev=prev_bev)
    losses.update(losses_pts)
    return losses
def val_step(self, data, optimizer):
    """Compute BEV features for one batch (non-standard ``val_step``).

    Instead of running validation, this returns the BEV features produced
    by the head so they can be used as ``prev_bev`` afterwards.

    Args:
        data (dict): Must contain ``'img'`` and ``'img_metas'``; may contain
            ``'prev_bev'``.
        optimizer: Unused here.

    Returns:
        The head output for ``only_bev=True``.
    """
    img = data['img']
    img_metas = data['img_metas']
    img_feats = self.extract_feat(img=img, img_metas=img_metas)
    prev_bev = data.get('prev_bev', None)
    # The head is called as (feats, lidar_feat, img_metas, ...) elsewhere in
    # this file. No lidar features are computed here, so pass None explicitly
    # rather than letting img_metas fill the lidar slot.
    prev_bev = self.pts_bbox_head(
        img_feats, None, img_metas, prev_bev=prev_bev, only_bev=True)
    return prev_bev
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/detectors/maptrv2.py
0 → 100644
View file @
19472568
import
copy
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models
import
DETECTORS
from
mmdet3d.core
import
bbox3d2result
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
projects.mmdet3d_plugin.models.utils.grid_mask
import
GridMask
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmdet3d.ops
import
Voxelization
,
DynamicScatter
from
mmdet3d.models
import
builder
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
@
DETECTORS
.
register_module
()
class
MapTRv2
(
MVXTwoStageDetector
):
"""MapTR.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def __init__(self,
             use_grid_mask=False,
             pts_voxel_layer=None,
             pts_voxel_encoder=None,
             pts_middle_encoder=None,
             pts_fusion_layer=None,
             img_backbone=None,
             pts_backbone=None,
             img_neck=None,
             pts_neck=None,
             pts_bbox_head=None,
             img_roi_head=None,
             img_rpn_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None,
             video_test_mode=False,
             modality='vision',
             lidar_encoder=None,
             ):
    """Build the detector.

    The sub-module configs are forwarded positionally to the parent
    constructor. On top of that this sets up the grid-mask augmentation,
    the temporal state used at test time and, for ``modality == 'fusion'``
    with a ``lidar_encoder`` config, the lidar voxelizer and backbone.
    """
    super(MapTRv2, self).__init__(
        pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
        pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
        pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg,
        pretrained)

    # The grid mask is always constructed; `use_grid_mask` gates its use.
    self.grid_mask = GridMask(
        True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
    self.use_grid_mask = use_grid_mask
    self.fp16_enabled = False

    # Temporal state carried between consecutive test frames.
    self.video_test_mode = video_test_mode
    self.prev_frame_info = dict(
        prev_bev=None, scene_token=None, prev_pos=0, prev_angle=0)

    self.modality = modality
    if self.modality != 'fusion' or lidar_encoder is None:
        return

    # A positive `max_num_points` selects Voxelization, otherwise
    # DynamicScatter is used.
    voxelize_cfg = lidar_encoder["voxelize"]
    if voxelize_cfg.get("max_num_points", -1) > 0:
        voxelizer_cls = Voxelization
    else:
        voxelizer_cls = DynamicScatter
    self.lidar_modal_extractor = nn.ModuleDict({
        "voxelize": voxelizer_cls(**lidar_encoder["voxelize"]),
        "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
    })
    self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def extract_img_feat(self, img, img_metas, len_queue=None):
    """Extract image features with the backbone (and neck, if present).

    Args:
        img (Tensor | None): Input images. A 5-D tensor is read as
            ``(B, N, C, H, W)`` and flattened to ``(B * N, C, H, W)``;
            any other rank is passed to the backbone unchanged.
        img_metas: Unused here.
        len_queue (int | None): When given, the leading dimension is
            additionally split into ``(B // len_queue, len_queue)``.

    Returns:
        list[Tensor] | None: One tensor per feature level, reshaped to
        ``(B, BN // B, C, H, W)`` (or
        ``(B // len_queue, len_queue, BN // B, C, H, W)``), or None when
        ``img`` is None.
    """
    # Check for None before touching the tensor; previously `img.size(0)`
    # ran first, so the None branch could never be reached.
    if img is None:
        return None

    B = img.size(0)
    if img.dim() == 5:
        # Non-mutating reshape: the old in-place `squeeze_()` altered the
        # caller's tensor and collapsed every singleton dimension.
        B, N, C, H, W = img.size()
        img = img.reshape(B * N, C, H, W)
    if self.use_grid_mask:
        img = self.grid_mask(img)

    img_feats = self.img_backbone(img)
    if isinstance(img_feats, dict):
        img_feats = list(img_feats.values())
    if self.with_img_neck:
        img_feats = self.img_neck(img_feats)

    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        if len_queue is not None:
            img_feats_reshaped.append(
                img_feat.view(B // len_queue, len_queue, BN // B, C, H, W))
        else:
            img_feats_reshaped.append(img_feat.view(B, BN // B, C, H, W))
    return img_feats_reshaped
@auto_fp16(apply_to=('img'), out_fp32=True)
def extract_feat(self, img, img_metas=None, len_queue=None):
    """Return the image features; thin wrapper around ``extract_img_feat``."""
    return self.extract_img_feat(img, img_metas, len_queue=len_queue)
def forward_pts_train(self,
                      pts_feats,
                      lidar_feat,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      img_metas,
                      gt_bboxes_ignore=None,
                      prev_bev=None,
                      gt_depth=None,
                      gt_seg_mask=None,
                      gt_pv_seg_mask=None,
                      ):
    """Run the head and collect depth, one-to-one and one-to-many losses.

    Args:
        pts_feats: Features handed to the head as its first argument.
        lidar_feat: Lidar features handed to the head (may be None).
        gt_bboxes_3d (list): Ground-truth containers exposing
            ``instance_list`` and ``instance_labels``.
        gt_labels_3d (list[Tensor]): Ground-truth labels per sample.
        img_metas: Meta information, given to both the head and its loss.
        gt_bboxes_ignore: Accepted for interface compatibility; not used.
        prev_bev: Previous BEV features handed to the head.
        gt_depth: Depth targets; the depth loss is skipped when None.
        gt_seg_mask: BEV segmentation targets forwarded to the head loss.
        gt_pv_seg_mask: PV segmentation targets forwarded to the head loss.

    Returns:
        dict: Head losses, optionally ``loss_depth``, plus every
        one-to-many loss scaled by ``lambda_one2many`` under the key
        ``<name>_one2many``.
    """
    head = self.pts_bbox_head
    outs = head(pts_feats, lidar_feat, img_metas, prev_bev)

    # The depth prediction is removed from the head outputs before the
    # detection loss is evaluated.
    depth = outs.pop('depth')
    losses = dict()
    if gt_depth is not None:
        loss_depth = head.transformer.encoder.get_depth_loss(gt_depth, depth)
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            loss_depth = torch.nan_to_num(loss_depth)
        losses.update(loss_depth=loss_depth)

    # One-to-one branch.
    one2one_losses = head.loss(
        gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, outs,
        img_metas=img_metas)
    losses.update(one2one_losses)

    # One-to-many branch: repeat every ground truth `k_one2many` times on
    # deep copies so the caller's targets are left untouched.
    repeat_factor = head.k_one2many
    multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
    multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
    for idx, (gt_container, gt_label) in enumerate(
            zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
        gt_container.instance_list = gt_container.instance_list * repeat_factor
        gt_container.instance_labels = gt_container.instance_labels * repeat_factor
        multi_gt_labels_3d[idx] = gt_label.repeat(repeat_factor)

    one2many_losses = head.loss(
        multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask, gt_pv_seg_mask,
        outs['one2many_outs'], img_metas=img_metas)

    weight = head.lambda_one2many
    for key, value in one2many_losses.items():
        target_key = key + "_one2many"
        if target_key in losses.keys():
            losses[target_key] += value * weight
        else:
            losses[target_key] = value * weight
    return losses
def forward_dummy(self, img):
    """Run ``forward_test`` on ``img`` with a placeholder ``[[None]]`` meta list."""
    return self.forward_test(img=img, img_metas=[[None]])
def forward(self, return_loss=True, **kwargs):
    """Dispatch to ``forward_train`` or ``forward_test``.

    Args:
        return_loss (bool): If truthy, ``forward_train(**kwargs)`` is
            called; otherwise ``forward_test(**kwargs)``.
        **kwargs: Forwarded unchanged to the selected method.

    Returns:
        The return value of the selected method.
    """
    if not return_loss:
        return self.forward_test(**kwargs)
    return self.forward_train(**kwargs)
def obtain_history_bev(self, imgs_queue, img_metas_list):
    """Compute the BEV features of the history frames without gradients.

    Args:
        imgs_queue (Tensor): History images, unpacked as
            ``(bs, len_queue, num_cams, C, H, W)``.
        img_metas_list (list): Per-sample sequences of meta information,
            indexed by frame.

    Returns:
        The BEV features returned by the head for the last history frame
        (None when ``len_queue`` is 0).
    """
    # Remember the current mode so it can be restored afterwards; the old
    # code always switched to train mode.
    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            prev_bev = None
            bs, len_queue, num_cams, C, H, W = imgs_queue.shape
            imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
            img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
            for i in range(len_queue):
                img_metas = [each[i] for each in img_metas_list]
                img_feats = [each_scale[:, i] for each_scale in img_feats_list]
                # Every other call site passes (feats, lidar_feat, img_metas,
                # prev_bev) to the head. No lidar features are extracted for
                # history frames, so the slot is filled with None instead of
                # letting img_metas slide into it.
                prev_bev = self.pts_bbox_head(
                    img_feats, None, img_metas, prev_bev, only_bev=True)
    finally:
        self.train(was_training)
    return prev_bev
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
    """Voxelize a batch of point clouds and concatenate the results.

    Args:
        points (Iterable): Point clouds; each is passed to
            ``self.lidar_modal_extractor["voxelize"]``.

    Returns:
        tuple: ``(feats, coords, sizes)``. ``coords`` has the sample index
        prepended as an extra leading column. ``sizes`` is an empty list
        when the voxelizer returns only ``(feats, coords)``.
    """
    feats, coords, sizes = [], [], []
    for sample_idx, cloud in enumerate(points):
        outputs = self.lidar_modal_extractor["voxelize"](cloud)
        if len(outputs) == 3:
            # hard voxelize
            voxel_feats, voxel_coords, num_points = outputs
        else:
            assert len(outputs) == 2
            voxel_feats, voxel_coords = outputs
            num_points = None
        feats.append(voxel_feats)
        # Prepend the sample index to every coordinate row.
        coords.append(
            F.pad(voxel_coords, (1, 0), mode="constant", value=sample_idx))
        if num_points is not None:
            sizes.append(num_points)

    feats = torch.cat(feats, dim=0)
    coords = torch.cat(coords, dim=0)
    if len(sizes) <= 0:
        return feats, coords, sizes

    sizes = torch.cat(sizes, dim=0)
    if self.voxelize_reduce:
        # Average over dim 1 using the per-voxel point counts.
        point_counts = sizes.type_as(feats).view(-1, 1)
        feats = feats.sum(dim=1, keepdim=False) / point_counts
        feats = feats.contiguous()
    return feats, coords, sizes
@auto_fp16(apply_to=('points'), out_fp32=True)
def extract_lidar_feat(self, points):
    """Voxelize ``points`` and run the lidar backbone on the voxels.

    The batch size is derived from the sample index stored in column 0 of
    the last coordinate row.
    """
    voxel_feats, voxel_coords, voxel_sizes = self.voxelize(points)
    num_samples = voxel_coords[-1, 0] + 1
    backbone = self.lidar_modal_extractor["backbone"]
    return backbone(voxel_feats, voxel_coords, num_samples, sizes=voxel_sizes)
# @auto_fp16(apply_to=('img', 'points'))
@force_fp32(apply_to=('img', 'points', 'prev_bev'))
def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  img_depth=None,
                  img_mask=None,
                  gt_depth=None,
                  gt_seg_mask=None,
                  gt_pv_seg_mask=None,
                  ):
    """Forward training function.

    Args:
        points (list[torch.Tensor], optional): Points of each sample. Only
            used when ``self.modality == 'fusion'``. Defaults to None.
        img_metas (list, optional): Per-sample meta information, indexed by
            frame position in the queue. Defaults to None.
        gt_bboxes_3d (list, optional): Ground truth passed on to
            ``forward_pts_train``. Defaults to None.
        gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
            passed on to ``forward_pts_train``. Defaults to None.
        gt_labels (optional): Not used by this method. Defaults to None.
        gt_bboxes (optional): Not used by this method. Defaults to None.
        img (torch.Tensor): Image queue whose dim 1 is the queue length;
            the last frame is the current one, earlier frames are history.
        proposals (optional): Not used by this method. Defaults to None.
        gt_bboxes_ignore (optional): Passed on to ``forward_pts_train``.
            Defaults to None.
        img_depth (optional): Not used by this method. Defaults to None.
        img_mask (optional): Not used by this method. Defaults to None.
        gt_depth (optional): Passed on to ``forward_pts_train``.
        gt_seg_mask (optional): Passed on to ``forward_pts_train``.
        gt_pv_seg_mask (optional): Passed on to ``forward_pts_train``.

    Returns:
        dict: Losses returned by ``forward_pts_train``.
    """
    lidar_feat = None
    if self.modality == 'fusion':
        lidar_feat = self.extract_lidar_feat(points)

    len_queue = img.size(1)
    prev_img = img[:, :-1, ...]
    img = img[:, -1, ...]

    # Only pay for the deep copy of the metas when there actually are
    # history frames to process.
    prev_bev = None
    if len_queue > 1:
        prev_img_metas = copy.deepcopy(img_metas)
        prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)

    # Keep only the meta information of the current (last) frame.
    img_metas = [each[len_queue - 1] for each in img_metas]
    img_feats = self.extract_feat(img=img, img_metas=img_metas)

    losses = dict()
    losses_pts = self.forward_pts_train(img_feats, lidar_feat, gt_bboxes_3d,
                                        gt_labels_3d, img_metas,
                                        gt_bboxes_ignore, prev_bev,
                                        gt_depth, gt_seg_mask, gt_pv_seg_mask)
    losses.update(losses_pts)
    return losses
def forward_test(self, img_metas, img=None, points=None, **kwargs):
    """Run inference on one frame while tracking temporal state.

    ``self.prev_frame_info`` keeps the scene token, the BEV feature and the
    ego position/angle of the previous frame. The ``can_bus`` entry of the
    current frame's meta is rewritten in place: to the delta relative to the
    previous frame, or to zero when there is no previous BEV feature.

    Args:
        img_metas (list): Nested metas; ``img_metas[0][0]`` is the meta dict
            of the current frame.
        img (list, optional): ``img[0]`` is forwarded to ``simple_test``.
        points (list, optional): ``points[0]`` is forwarded to
            ``simple_test``.
        **kwargs: Forwarded to ``simple_test``.

    Returns:
        The detection results returned by ``simple_test``.

    Raises:
        TypeError: If ``img_metas`` is not a list.
    """
    if not isinstance(img_metas, list):
        raise TypeError('{} must be a list, but got {}'.format(
            'img_metas', type(img_metas)))
    if img is None:
        img = [img]
    if points is None:
        points = [points]

    frame_meta = img_metas[0][0]
    state = self.prev_frame_info

    # The first sample of each scene is truncated: drop the stale BEV.
    if frame_meta['scene_token'] != state['scene_token']:
        state['prev_bev'] = None
    state['scene_token'] = frame_meta['scene_token']

    # Temporal information is disabled outside video test mode.
    if not self.video_test_mode:
        state['prev_bev'] = None

    # Keep the absolute ego pose before can_bus is overwritten below.
    can_bus = frame_meta['can_bus']
    tmp_pos = copy.deepcopy(can_bus[:3])
    tmp_angle = copy.deepcopy(can_bus[-1])
    if state['prev_bev'] is None:
        can_bus[-1] = 0
        can_bus[:3] = 0
    else:
        can_bus[:3] -= state['prev_pos']
        can_bus[-1] -= state['prev_angle']

    new_prev_bev, bbox_results = self.simple_test(
        img_metas[0], img[0], points[0],
        prev_bev=state['prev_bev'], **kwargs)

    # Save the BEV feature and ego motion of this timestamp for the next call.
    state['prev_pos'] = tmp_pos
    state['prev_angle'] = tmp_angle
    state['prev_bev'] = new_prev_bev
    return bbox_results
def pred2result(self, bboxes, scores, labels, pts, attrs=None):
    """Pack the predictions of one sample into a dict of CPU objects.

    Args:
        bboxes: Predicted boxes; moved with ``.to('cpu')``.
        scores (torch.Tensor): Prediction scores.
        labels (torch.Tensor): Predicted labels.
        pts: Predicted points; moved with ``.to('cpu')``.
        attrs (torch.Tensor, optional): Attributes. Defaults to None.

    Returns:
        dict: Results in cpu mode.

            - boxes_3d: The boxes.
            - scores_3d: Prediction scores.
            - labels_3d: Predicted labels.
            - pts_3d: The points.
            - attrs_3d (optional): Attributes, only when ``attrs`` is given.
    """
    result_dict = {
        'boxes_3d': bboxes.to('cpu'),
        'scores_3d': scores.cpu(),
        'labels_3d': labels.cpu(),
        'pts_3d': pts.to('cpu'),
    }
    if attrs is None:
        return result_dict
    result_dict['attrs_3d'] = attrs.cpu()
    return result_dict
def simple_test_pts(self, x, lidar_feat, img_metas, prev_bev=None, rescale=False):
    """Run the head on extracted features and decode its predictions.

    Args:
        x: Image features passed to ``self.pts_bbox_head``.
        lidar_feat: Lidar feature passed to the head (may be None).
        img_metas: Meta information of the samples.
        prev_bev (optional): BEV feature of the previous frame.
        rescale (bool): Forwarded to ``get_bboxes``. Defaults to False.

    Returns:
        tuple: ``(outs['bev_embed'], bbox_results)`` where ``bbox_results``
        holds one ``pred2result`` dict per decoded sample.
    """
    head = self.pts_bbox_head
    outs = head(x, lidar_feat, img_metas, prev_bev=prev_bev)
    decoded = head.get_bboxes(outs, img_metas, rescale=rescale)
    bbox_results = []
    for bboxes, scores, labels, pts in decoded:
        bbox_results.append(self.pred2result(bboxes, scores, labels, pts))
    return outs['bev_embed'], bbox_results
def simple_test(self, img_metas, img=None, points=None, prev_bev=None, rescale=False, **kwargs):
    """Test function without augmentation.

    Args:
        img_metas: Meta information, one entry per sample.
        img (optional): Images passed to ``extract_feat``.
        points (optional): Point clouds; only used when
            ``self.modality == 'fusion'``.
        prev_bev (optional): BEV feature of the previous frame.
        rescale (bool): Forwarded to ``simple_test_pts``.
        **kwargs: Accepted but not used by this method.

    Returns:
        tuple: ``(new_prev_bev, bbox_list)`` where ``bbox_list`` has one dict
        per entry of ``img_metas`` with the result under ``'pts_bbox'``.
    """
    if self.modality == 'fusion':
        lidar_feat = self.extract_lidar_feat(points)
    else:
        lidar_feat = None

    img_feats = self.extract_feat(img=img, img_metas=img_metas)
    new_prev_bev, bbox_pts = self.simple_test_pts(
        img_feats, lidar_feat, img_metas, prev_bev, rescale=rescale)

    bbox_list = [dict() for _ in range(len(img_metas))]
    for slot, pts_bbox in zip(bbox_list, bbox_pts):
        slot['pts_bbox'] = pts_bbox
    return new_prev_bev, bbox_list
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/__init__.py
0 → 100644
View file @
19472568
from
.map_loss
import
MyChamferDistance
from
.map_loss
import
MyChamferDistanceCost
from
.map_loss
import
OrderedPtsL1Cost
,
PtsL1Cost
from
.map_loss
import
OrderedPtsL1Loss
,
PtsL1Loss
from
.map_loss
import
OrderedPtsSmoothL1Cost
,
OrderedPtsL1Loss
from
.map_loss
import
PtsDirCosLoss
from
.simple_loss
import
SimpleLoss
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/map_loss.py
0 → 100644
View file @
19472568
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
torch
import
nn
as
nn
from
torch.nn.functional
import
l1_loss
,
mse_loss
,
smooth_l1_loss
from
mmdet.models.builder
import
LOSSES
from
mmdet.models
import
weighted_loss
import
mmcv
import
torch.nn.functional
as
F
from
mmdet.core.bbox.match_costs.builder
import
MATCH_COST
import
functools
def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor as requested.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    # Enum values used by torch: none -> 0, mean -> 1, sum -> 2.
    code = F._Reduction.get_enum(reduction)
    if code == 0:
        return loss
    if code == 1:
        return loss.mean()
    if code == 2:
        return loss.sum()
@mmcv.jit(derivate=True, coderize=True)
def custom_weight_dir_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply an element-wise weight and reduce a direction loss.

    Args:
        loss (Tensor): num_sample, num_dir
        weight (Tensor): Element-wise weights.
        reduction (str): 'mean' or 'none'.
        avg_factor (float): Average factor; required.

    Returns:
        Tensor: With 'mean', the sum of the weighted loss divided by
        ``avg_factor``; with 'none', the weighted loss unchanged.

    Raises:
        ValueError: If ``avg_factor`` is None, or if ``reduction`` is
            neither 'mean' nor 'none'.
    """
    if weight is not None:
        loss = loss * weight

    # Unlike the generic reducer, an explicit avg_factor is mandatory here.
    if avg_factor is None:
        raise ValueError('avg_factor should not be none for OrderedPtsL1Loss')

    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction != 'none':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
@mmcv.jit(derivate=True, coderize=True)
def custom_weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply an element-wise weight and reduce an ordered-points loss.

    Args:
        loss (Tensor): num_sample, num_order, num_pts, num_coords
        weight (Tensor): Element-wise weights.
        reduction (str): 'mean' or 'none'.
        avg_factor (float): Average factor; required.

    Returns:
        Tensor: With 'mean', one value per order (the loss summed over
        samples, points and coordinates, divided by ``avg_factor``); with
        'none', the weighted loss unchanged.

    Raises:
        ValueError: If ``avg_factor`` is None, or if ``reduction`` is
            neither 'mean' nor 'none'.
    """
    if weight is not None:
        loss = loss * weight

    # Unlike the generic reducer, an explicit avg_factor is mandatory here.
    if avg_factor is None:
        raise ValueError('avg_factor should not be none for OrderedPtsL1Loss')

    if reduction == 'mean':
        # Bring the order axis to the front, then collapse everything else.
        per_order = loss.permute(1, 0, 2, 3).contiguous().sum((1, 2, 3))
        return per_order / avg_factor
    if reduction != 'none':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
def custom_weighted_loss(loss_func):
    """Decorator adding weighting and reduction to an element-wise loss.

    The decorated function must have the signature
    `loss_func(pred, target, **kwargs)` and return the element-wise loss
    without any reduction. The resulting function has the signature
    `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)` and reduces the loss with
    `custom_weight_reduce_loss`, so `avg_factor` must be given.
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return custom_weight_reduce_loss(elementwise, weight, reduction, avg_factor)

    return wrapper
def custom_weighted_dir_loss(loss_func):
    """Decorator adding weighting and reduction to a direction loss.

    The decorated function must have the signature
    `loss_func(pred, target, **kwargs)` and return the element-wise loss
    without any reduction. The resulting function has the signature
    `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)` and reduces the loss with
    `custom_weight_dir_reduce_loss`, so `avg_factor` must be given.
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return custom_weight_dir_reduce_loss(elementwise, weight, reduction, avg_factor)

    return wrapper
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_loss
def ordered_pts_smooth_l1_loss(pred, target):
    """Element-wise smooth L1 loss against every ordering of the target.

    Args:
        pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
        target (torch.Tensor): shape [num_samples, num_order, num_pts, num_coords]

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0
    # Repeat the prediction once per candidate ordering of the target.
    num_order = target.size(1)
    tiled_pred = pred.unsqueeze(1).repeat(1, num_order, 1, 1)
    assert tiled_pred.size() == target.size()
    return smooth_l1_loss(tiled_pred, target, reduction='none')
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def pts_l1_loss(pred, target):
    """Element-wise L1 loss between predicted and target points.

    Args:
        pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
        target (torch.Tensor): shape [num_samples, num_pts, num_coords]

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    return torch.abs(pred - target)
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_loss
def ordered_pts_l1_loss(pred, target):
    """Element-wise L1 loss against every ordering of the target.

    Args:
        pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
        target (torch.Tensor): shape [num_samples, num_order, num_pts, num_coords]

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0
    # Repeat the prediction once per candidate ordering of the target.
    num_order = target.size(1)
    tiled_pred = pred.unsqueeze(1).repeat(1, num_order, 1, 1)
    assert tiled_pred.size() == target.size()
    return torch.abs(tiled_pred - target)
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_dir_loss
def pts_dir_cos_loss(pred, target):
    """Direction cosine-similarity loss.

    Args:
        pred (torch.Tensor): shape [num_samples, num_dir, num_coords]
        target (torch.Tensor): shape [num_samples, num_dir, num_coords]

    Returns:
        torch.Tensor: Loss of shape [num_samples, num_dir].
    """
    if target.numel() == 0:
        return pred.sum() * 0
    num_samples, num_dir, num_coords = pred.shape
    # Every pair is labelled +1 for the cosine embedding loss.
    similar = target.new_ones(num_samples * num_dir)
    flat_loss = F.cosine_embedding_loss(
        pred.flatten(0, 1), target.flatten(0, 1), similar, reduction='none')
    return flat_loss.view(num_samples, num_dir)
@LOSSES.register_module()
class OrderedPtsSmoothL1Loss(nn.Module):
    """Smooth L1 loss over ordered point sets.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        raw_loss = ordered_pts_smooth_l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * raw_loss
@LOSSES.register_module()
class PtsDirCosLoss(nn.Module):
    """Direction cosine-similarity loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        raw_loss = pts_dir_cos_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * raw_loss
@LOSSES.register_module()
class PtsL1Loss(nn.Module):
    """L1 loss over point sets.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        raw_loss = pts_l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * raw_loss
@LOSSES.register_module()
class OrderedPtsL1Loss(nn.Module):
    """L1 loss over ordered point sets.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        raw_loss = ordered_pts_l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * raw_loss
@MATCH_COST.register_module()
class OrderedPtsSmoothL1Cost(object):
    """Smooth L1 matching cost against every ordering of every ground truth.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted points with shape
                [num_query, num_pts, 2].
            gt_bboxes (Tensor): Ground truth points with shape
                [num_gt, num_ordered, num_pts, 2].

        Returns:
            torch.Tensor: Weighted cost of shape
                [num_query, num_gt * num_ordered].
        """
        num_gts, num_orders, num_pts, num_coords = gt_bboxes.shape
        num_queries = bbox_pred.size(0)
        num_candidates = num_gts * num_orders

        # Flatten the points of each polyline, then tile both sides so that
        # every query is paired with every (gt, order) candidate.
        pred_flat = bbox_pred.view(num_queries, -1)
        gt_flat = gt_bboxes.flatten(2).view(num_candidates, -1)
        pred_tiled = pred_flat.unsqueeze(1).repeat(1, num_candidates, 1)
        gt_tiled = gt_flat.unsqueeze(0).repeat(num_queries, 1, 1)

        bbox_cost = smooth_l1_loss(pred_tiled, gt_tiled, reduction='none').sum(-1)
        return bbox_cost * self.weight
@MATCH_COST.register_module()
class PtsL1Cost(object):
    """L1 matching cost between predicted and ground-truth point sets.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted points with shape
                [num_query, num_pts, 2].
            gt_bboxes (Tensor): Ground truth points; unpacked as a 3-D
                tensor [num_gt, num_pts, 2].

        Returns:
            torch.Tensor: Weighted cost of shape [num_query, num_gt].
        """
        num_gts, num_pts, num_coords = gt_bboxes.shape
        pred_flat = bbox_pred.view(bbox_pred.size(0), -1)
        gt_flat = gt_bboxes.view(num_gts, -1)
        pairwise_l1 = torch.cdist(pred_flat, gt_flat, p=1)
        return pairwise_l1 * self.weight
@MATCH_COST.register_module()
class OrderedPtsL1Cost(object):
    """L1 matching cost against every ordering of every ground truth.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted points with shape
                [num_query, num_pts, 2].
            gt_bboxes (Tensor): Ground truth points with shape
                [num_gt, num_ordered, num_pts, 2].

        Returns:
            torch.Tensor: Weighted cost of shape
                [num_query, num_gt * num_ordered].
        """
        num_gts, num_orders, num_pts, num_coords = gt_bboxes.shape
        pred_flat = bbox_pred.view(bbox_pred.size(0), -1)
        gt_flat = gt_bboxes.flatten(2).view(num_gts * num_orders, -1)
        # Pairwise L1 distance written out with broadcasting
        # (instead of torch.cdist(pred_flat, gt_flat, p=1)).
        diff = pred_flat[:, None, :] - gt_flat[None, :, :]
        bbox_cost = diff.abs().sum(dim=-1)
        return bbox_cost * self.weight
@MATCH_COST.register_module()
class MyChamferDistanceCost:
    """Chamfer-distance matching cost between queries and ground truths.

    Args:
        loss_src_weight (float): Weight of the source-to-destination term.
        loss_dst_weight (float): Weight of the destination-to-source term.
    """

    def __init__(self, loss_src_weight=1., loss_dst_weight=1.):
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def __call__(self, src, dst, src_weight=1.0, dst_weight=1.0,):
        """
        Args:
            src (Tensor): normed coordinate(x,y), shape (num_q, num_pts_M, 2)
            dst (Tensor): normed coordinate(x,y), shape (num_gt, num_pts_N, 2)
            src_weight (Tensor | float): Weight of the source distances.
            dst_weight (Tensor | float): Weight of the destination distances.

        Returns:
            Tensor: Cost of shape (num_q, num_gt).
        """
        num_q, num_gt = src.shape[0], dst.shape[0]
        # Pair every query with every ground truth.
        src_expand = src.unsqueeze(1).repeat(1, num_gt, 1, 1)
        dst_expand = dst.unsqueeze(0).repeat(num_q, 1, 1, 1)

        # (num_q, num_gt, num_pts_M, num_pts_N)
        distance = torch.cdist(src_expand, dst_expand)
        # nearest destination point per source point: (num_q, num_gt, num_pts_M)
        src2dst_distance = distance.min(dim=3)[0]
        # nearest source point per destination point: (num_q, num_gt, num_pts_N)
        dst2src_distance = distance.min(dim=2)[0]

        loss_src = (src2dst_distance * src_weight).mean(-1)
        loss_dst = (dst2src_distance * dst_weight).mean(-1)
        return loss_src * self.loss_src_weight + loss_dst * self.loss_dst_weight
@mmcv.jit(derivate=True, coderize=True)
def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     reduction='mean',
                     avg_factor=None):
    """Calculate Chamfer Distance of two sets.

    Args:
        src (torch.Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (torch.Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (torch.Tensor or float): Weight of source loss.
        dst_weight (torch.Tensor or float): Weight of destination loss.
        reduction (str): Method to reduce losses. Without ``avg_factor``
            only 'sum' and 'mean' are accepted; with ``avg_factor`` only
            'mean' and 'none' are accepted.
        avg_factor (float, optional): When given with reduction 'mean', the
            summed per-sample losses are divided by ``avg_factor + eps``.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

            - loss_src (torch.Tensor): The min distance
                from source to destination.
            - loss_dst (torch.Tensor): The min distance
                from destination to source.
            - indices1 (torch.Tensor): Index the min distance point
                for each point in source to destination.
            - indices2 (torch.Tensor): Index the min distance point
                for each point in destination to source.
    """
    distance = torch.cdist(src, dst)
    src2dst_distance, indices1 = distance.min(dim=2)  # (B,N)
    dst2src_distance, indices2 = distance.min(dim=1)  # (B,M)

    # TODO this may be wrong for misaligned src_weight, now [N, fixed_num];
    # should be [N], then view
    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if avg_factor is not None:
        if reduction == 'mean':
            eps = torch.finfo(torch.float32).eps
            loss_src = loss_src.mean(-1).sum() / (avg_factor + eps)
            loss_dst = loss_dst.mean(-1).sum() / (avg_factor + eps)
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
        return loss_src, loss_dst, indices1, indices2

    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 0:
        raise ValueError('MyCDLoss can not be used with reduction=`none`')
    if reduction_enum == 1:
        loss_src = loss_src.mean(-1).mean()
        loss_dst = loss_dst.mean(-1).mean()
    elif reduction_enum == 2:
        loss_src = loss_src.mean(-1).sum()
        loss_dst = loss_dst.mean(-1).sum()
    else:
        raise NotImplementedError
    return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
class MyChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super().__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                avg_factor=None,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        """Forward function of loss calculation.

        Args:
            source (torch.Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (torch.Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (torch.Tensor | float, optional):
                Weight of source loss. Defaults to 1.0.
            dst_weight (torch.Tensor | float, optional):
                Weight of destination loss. Defaults to 1.0.
            avg_factor (float, optional): Forwarded to ``chamfer_distance``.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                Defaults to None.
            return_indices (bool, optional): Whether to return indices.
                Defaults to False.

        Returns:
            torch.Tensor | tuple: The weighted sum of the source and target
            losses. If ``return_indices=True``, the tuple
            ``(loss_pts, indices1, indices2)`` is returned instead.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, reduction,
            avg_factor=avg_factor)

        loss_pts = (loss_source * self.loss_src_weight
                    + loss_target * self.loss_dst_weight)
        if not return_indices:
            return loss_pts
        return loss_pts, indices1, indices2
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/losses/simple_loss.py
0 → 100644
View file @
19472568
import
torch
import
torch.nn
as
nn
from
mmdet.models.builder
import
LOSSES
import
torch.nn.functional
as
F
from
mmdet.models.losses
import
FocalLoss
,
weight_reduce_loss
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction logits.
        target (torch.Tensor): The learning label of the prediction; cast to
            the type of ``pred``.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    prob = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability assigned to the wrong side of each label.
    pt = (1 - prob) * target + prob * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight

    if weight is not None and weight.shape != loss.shape:
        if weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (num_priors, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        else:
            # Sometimes a per-element weight is passed flattened, while
            # loss still keeps its original shape.
            assert weight.numel() == loss.numel()
            weight = weight.view(loss.size(0), -1)
    if weight is not None:
        assert weight.ndim == loss.ndim

    return weight_reduce_loss(loss, weight, reduction, avg_factor)
@LOSSES.register_module(force=True)
class SimpleLoss_v1(nn.Module):
    """Binary cross entropy over the foreground cells of a label map.

    Args:
        pos_weight: Accepted but not used by this class.
        loss_weight (float): Factor applied to the loss.
    """

    def __init__(self, pos_weight, loss_weight):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, ypred, ytgt):
        """Compute the loss.

        Args:
            ypred (torch.Tensor): Logits, unpacked as
                (bs, pred_class_num, bev_h, bev_w).
            ytgt (torch.Tensor): Class indices, flattened to one value per
                cell; index 0 is dropped after one-hot encoding.

        Returns:
            torch.Tensor: Summed BCE over foreground cells divided by
            ``max(1.0, number of foreground cells)``, times ``loss_weight``.
        """
        bs, pred_class_num, bev_h, bev_w = ypred.shape
        # One row of class logits per cell.
        logits = ypred.permute(0, 2, 3, 1).reshape(
            bs * bev_h * bev_w, pred_class_num).contiguous()

        labels = ytgt.view(-1).long()
        one_hot = F.one_hot(labels, num_classes=pred_class_num + 1)
        # Drop column 0 so that only the remaining classes stay.
        one_hot = one_hot.view(-1, pred_class_num + 1)[:, 1:]

        fg_mask = one_hot.max(dim=1).values > 0.0
        logits = logits[fg_mask]
        one_hot = one_hot[fg_mask]

        bce = F.binary_cross_entropy_with_logits(
            logits, one_hot.float(), reduction='none')
        loss = bce.sum() / max(1.0, fg_mask.sum())
        return loss * self.loss_weight
@LOSSES.register_module()
class SimpleLoss(torch.nn.Module):
    """Weighted ``BCEWithLogitsLoss`` with a positive-class weight.

    Args:
        pos_weight (float): Value wrapped into the ``pos_weight`` tensor of
            ``torch.nn.BCEWithLogitsLoss``.
        loss_weight (float): Factor applied to the loss.
    """

    def __init__(self, pos_weight, loss_weight):
        super().__init__()
        positive_weight = torch.Tensor([pos_weight])
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight)
        self.loss_weight = loss_weight

    def forward(self, ypred, ytgt):
        """Return the BCE-with-logits loss scaled by ``loss_weight``."""
        bce = self.loss_fn(ypred, ytgt)
        return bce * self.loss_weight
@LOSSES.register_module()
class MaskFocalLoss(FocalLoss):
    """Sigmoid focal loss applied per channel and averaged over channels."""

    def __init__(self, **kwargs):
        super(MaskFocalLoss, self).__init__(**kwargs)

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Average the per-channel sigmoid focal loss.

        Args:
            pred (Tensor): Predictions; axis 1 is iterated as the channel axis.
            target (Tensor): Targets indexed the same way as ``pred``.
            weight: Forwarded unchanged to ``py_sigmoid_focal_loss``.
            avg_factor: Forwarded unchanged to ``py_sigmoid_focal_loss``.
            reduction_override (str | None): One of None, 'none', 'mean',
                'sum'; a truthy value replaces ``self.reduction``.

        Returns:
            The sum over channels of ``loss_weight * focal_loss`` divided by
            the number of channels.

        Raises:
            NotImplementedError: If ``self.use_sigmoid`` is false.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        if reduction_override:
            reduction = reduction_override
        else:
            reduction = self.reduction

        if not self.use_sigmoid:
            raise NotImplementedError

        channel_count = pred.size(1)
        total = 0
        for channel in range(channel_count):
            channel_loss = py_sigmoid_focal_loss(
                pred[:, channel],
                target[:, channel],
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
            total += self.loss_weight * channel_loss
        total /= channel_count
        return total
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/__init__.py
0 → 100644
View file @
19472568
from
.transformer
import
MapTRPerceptionTransformer
from
.decoder
import
MapTRDecoder
,
DecoupledDetrTransformerDecoderLayer
from
.geometry_kernel_attention
import
GeometrySptialCrossAttention
,
GeometryKernelAttention
from
.builder
import
build_fuser
from
.encoder
import
LSSTransform
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/builder.py
0 → 100644
View file @
19472568
import
torch.nn
as
nn
from
mmcv.utils
import
Registry
,
build_from_cfg
# Registry that holds the fuser modules built by ``build_fuser``.
FUSERS = Registry("fusers")


def build_fuser(cfg):
    """Build a fuser from ``cfg`` through the ``FUSERS`` registry."""
    fuser = FUSERS.build(cfg)
    return fuser
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/decoder.py
0 → 100644
View file @
19472568
import
torch
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
POSITIONAL_ENCODING
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
,
BaseTransformerLayer
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class MapTRDecoder(TransformerLayerSequence):
    """Layer-sequence decoder with optional iterative reference refinement.

    Args:
        return_intermediate (bool): Whether to return the output and the
            reference points of every layer instead of only the last one.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(MapTRDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.fp16_enabled = False

    #@torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                **kwargs):
        """Run every decoder layer, refining the reference points in between.

        Args:
            query (Tensor): Input query; each layer output is permuted
                (1, 0, 2) before the regression branch and permuted back
                afterwards.
            reference_points (Tensor): Reference points; only the first two
                entries of the last axis are handed to the layers.
            reg_branches (nn.ModuleList | None): Per-layer regression heads.
                When given, their output is added to the inverse-sigmoid of
                the current reference points and squashed back with sigmoid.
            key_padding_mask: Forwarded to every layer.

        Returns:
            tuple: ``(output, reference_points)`` of the last layer, or, when
            ``return_intermediate`` is set, the per-layer outputs and
            per-layer reference points each stacked along a new first axis.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # BS NUM_QUERY NUM_LEVEL 2
            reference_points_input = reference_points[..., :2].unsqueeze(2)
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                key_padding_mask=key_padding_mask,
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                # Refine in logit space, then map back to (0, 1). The result
                # is detached so the next layer's references carry no grad.
                new_reference_points = tmp + inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
@TRANSFORMER_LAYER.register_module()
class DecoupledDetrTransformerDecoderLayer(BaseTransformerLayer):
    """Decoder layer whose two self-attentions use different query groupings.

    The first self-attention treats the ``num_vec`` axis as the sequence
    axis; every later self-attention treats the ``num_pts_per_vec`` axis as
    the sequence axis (see the reshapes in ``forward``).

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        num_vec (int): Stored on the instance. Default: 50.
        num_pts_per_vec (int): Stored on the instance. Default: 20.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operations.
            Must have 8 entries drawn from exactly
            {'self_attn', 'norm', 'cross_attn', 'ffn'}.
        act_cfg (dict): The activation config for FFNs.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 num_vec=50,
                 num_pts_per_vec=20,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(DecoupledDetrTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        assert len(operation_order) == 8
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
        self.num_vec = num_vec
        self.num_pts_per_vec = num_pts_per_vec

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for the decoupled decoder layer.

        ``kwargs`` must contain ``num_vec``, ``num_pts_per_vec`` and
        ``self_attn_mask``; all of ``kwargs`` is also forwarded to every
        attention module.

        Args:
            query (Tensor): Input query with shape
                [num_vec * num_pts_per_vec, bs, embed_dims].
            key (Tensor): The key tensor passed to cross-attention.
            value (Tensor): The value tensor passed to cross-attention.
            query_pos (Tensor): The positional encoding for `query`; it is
                reshaped alongside `query` in the self-attention steps.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | Tensor | None): Masks per attention.
                A single tensor is deep-copied once per attention.
            query_key_padding_mask (Tensor): Passed as ``key_padding_mask``
                to the self-attention modules. Defaults to None.
            key_padding_mask (Tensor): Passed to cross-attention.
                Default: None.

        Returns:
            Tensor: forwarded results with the same layout as ``query``.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            # Imported locally: the module header of this file imports
            # neither `copy` nor `warnings`, so this branch used to fail
            # with NameError.
            import copy
            import warnings
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        num_vec = kwargs['num_vec']
        num_pts_per_vec = kwargs['num_pts_per_vec']
        for layer in self.operation_order:
            if layer == 'self_attn':
                _, n_batch, n_dim = query.shape
                grouped = (num_vec, num_pts_per_vec, n_batch, n_dim)
                if attn_index == 0:
                    # Sequence axis = num_vec; (points, batch) are folded
                    # into the batch axis.
                    query = query.view(*grouped).flatten(1, 2)
                    query_pos = query_pos.view(*grouped).flatten(1, 2)
                    temp_key = temp_value = query
                    query = self.attentions[attn_index](
                        query,
                        temp_key,
                        temp_value,
                        identity if self.pre_norm else None,
                        query_pos=query_pos,
                        key_pos=query_pos,
                        attn_mask=kwargs['self_attn_mask'],
                        key_padding_mask=query_key_padding_mask,
                        **kwargs)
                    # Restore the flat [num_vec * num_pts, bs, dim] layout.
                    query = query.view(*grouped).flatten(0, 1)
                    query_pos = query_pos.view(*grouped).flatten(0, 1)
                else:
                    # Sequence axis = num_pts_per_vec; (vectors, batch) are
                    # folded into the batch axis.
                    query = query.view(*grouped).permute(
                        1, 0, 2, 3).contiguous().flatten(1, 2)
                    query_pos = query_pos.view(*grouped).permute(
                        1, 0, 2, 3).contiguous().flatten(1, 2)
                    temp_key = temp_value = query
                    query = self.attentions[attn_index](
                        query,
                        temp_key,
                        temp_value,
                        identity if self.pre_norm else None,
                        query_pos=query_pos,
                        key_pos=query_pos,
                        attn_mask=attn_masks[attn_index],
                        key_padding_mask=query_key_padding_mask,
                        **kwargs)
                    # Undo the permutation and flatten back.
                    swapped = (num_pts_per_vec, num_vec, n_batch, n_dim)
                    query = query.view(*swapped).permute(
                        1, 0, 2, 3).contiguous().flatten(0, 1)
                    query_pos = query_pos.view(*swapped).permute(
                        1, 0, 2, 3).contiguous().flatten(0, 1)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/encoder.py
0 → 100644
View file @
19472568
import
torch
import
numpy
as
np
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
import
torch.nn
as
nn
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmdet3d.ops
import
bev_pool
from
mmdet3d.ops.bev_pool_v2.bev_pool
import
bev_pool_v2
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
torch.cuda.amp.autocast_mode
import
autocast
from
mmcv.cnn
import
build_conv_layer
from
mmdet.models.backbones.resnet
import
BasicBlock
,
Bottleneck
import
torch.nn.functional
as
F
from
projects.mmdet3d_plugin.bevformer.modules.encoder
import
BEVFormerEncoder
torch
.
set_float32_matmul_precision
(
'high'
)
def gen_dx_bx(xbound, ybound, zbound):
    """Derive voxel step, first-voxel centre and voxel count per axis.

    Each bound is indexed as ``(lower, upper, step)``.

    Returns:
        tuple[Tensor, Tensor, Tensor]: ``dx`` (the three steps), ``bx``
        (``lower + step / 2`` per axis) and ``nx``
        (``int((upper - lower) / step)`` per axis), all built with
        ``torch.Tensor``.
    """
    bounds = (xbound, ybound, zbound)
    steps = [bound[2] for bound in bounds]
    first_centres = [bound[0] + bound[2] / 2.0 for bound in bounds]
    counts = [int((bound[1] - bound[0]) / bound[2]) for bound in bounds]
    return torch.Tensor(steps), torch.Tensor(first_centres), torch.Tensor(counts)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BaseTransform(BaseModule):
    """Base class that lifts per-camera features onto a voxel/BEV grid.

    Subclasses must implement ``get_cam_feats`` and ``get_mlp_input``.

    Args:
        in_channels: Stored as ``self.in_channels``.
        out_channels: Stored as ``self.C``.
        feat_down_sample: Ratio between image height and feature height
            (asserted in ``create_frustum``).
        pc_range: Indexed as (x_min, y_min, z_min, x_max, y_max, z_max).
        voxel_size: Indexed as (step_x, step_y, step_z).
        dbound: Unpacked into ``torch.arange`` as the depth bins.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        feat_down_sample,
        pc_range,
        voxel_size,
        dbound,
    ):
        super(BaseTransform, self).__init__()
        self.in_channels = in_channels
        self.feat_down_sample = feat_down_sample
        # (lower, upper, step) per axis.
        self.xbound = [pc_range[0], pc_range[3], voxel_size[0]]
        self.ybound = [pc_range[1], pc_range[4], voxel_size[1]]
        self.zbound = [pc_range[2], pc_range[5], voxel_size[2]]
        self.dbound = dbound

        dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound)
        # Non-trainable parameters so they follow the module across devices.
        self.dx = nn.Parameter(dx, requires_grad=False)
        self.bx = nn.Parameter(bx, requires_grad=False)
        self.nx = nn.Parameter(nx, requires_grad=False)

        self.C = out_channels
        # Built lazily on the first geometry call, once fH / fW are known.
        self.frustum = None
        self.D = int((dbound[1] - dbound[0]) / dbound[2])
        self.fp16_enabled = False

    @force_fp32()
    def create_frustum(self, fH, fW, img_metas):
        """Build the (D, fH, fW, 3) grid of (x_pixel, y_pixel, depth).

        The image size is read from ``img_metas[0]['img_shape'][0]``.
        """
        iH = img_metas[0]['img_shape'][0][0]
        iW = img_metas[0]['img_shape'][0][1]
        assert iH // self.feat_down_sample == fH

        ds = (
            torch.arange(*self.dbound, dtype=torch.float)
            .view(-1, 1, 1)
            .expand(-1, fH, fW)
        )
        D, _, _ = ds.shape

        xs = (
            torch.linspace(0, iW - 1, fW, dtype=torch.float)
            .view(1, 1, fW)
            .expand(D, fH, fW)
        )
        ys = (
            torch.linspace(0, iH - 1, fH, dtype=torch.float)
            .view(1, fH, 1)
            .expand(D, fH, fW)
        )

        frustum = torch.stack((xs, ys, ds), -1)
        return frustum

    #@torch.compile(mode="max-autotune-no-cudagraphs")
    def matmul_1(self, x, y, trans):
        """Apply per-camera 3x3 ``x`` to ``y``, then scale (u, v) by depth."""
        B, N, _ = trans.shape
        points = torch.matmul(x.view(B, N, 1, 1, 1, 3, 3), y.unsqueeze(-1))
        # (u, v, d) -> (u * d, v * d, d)
        points = torch.cat(
            (
                points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
                points[:, :, :, :, :, 2:3],
            ),
            5,
        )
        return points

    #@torch.compile(mode="max-autotune-no-cudagraphs")
    def matmul_2(self, x, y, trans, lidar2ego_trans, points):
        """Apply ``x @ y`` to ``points``, add ``trans``, subtract
        ``lidar2ego_trans``."""
        B, N, _ = trans.shape
        combine = torch.matmul(x, y)
        points = torch.matmul(combine.view(B, N, 1, 1, 1, 3, 3), points).squeeze(-1)
        points += trans.view(B, N, 1, 1, 1, 3)
        points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
        return points

    #@torch.compile(mode="max-autotune-no-cudagraphs")
    def matmul_3(self, x, y, trans):
        """Apply the per-sample 3x3 ``x`` to every point in ``y``."""
        B, N, _ = trans.shape
        points = torch.matmul(x.view(B, 1, 1, 1, 1, 3, 3), y.unsqueeze(-1)).squeeze(-1)
        return points

    def _ensure_frustum(self, fH, fW, img_metas, device):
        """Create the cached frustum on first use and keep it on ``device``."""
        if self.frustum is None:
            self.frustum = self.create_frustum(fH, fW, img_metas)
        self.frustum = self.frustum.to(device)
        return self.frustum

    @force_fp32()
    def get_geometry_v1(
        self,
        fH,
        fW,
        rots,
        trans,
        intrins,
        post_rots,
        post_trans,
        lidar2ego_rots,
        lidar2ego_trans,
        img_metas,
        **kwargs,
    ):
        """Map the frustum through the inverse image augmentation, the
        inverse intrinsics, the camera pose and the inverse lidar2ego pose.

        Optional ``extra_rots`` / ``extra_trans`` in ``kwargs`` are applied
        last. Returns a tensor shaped (B, N, D, fH, fW, 3).
        """
        B, N, _ = trans.shape
        frustum = self._ensure_frustum(fH, fW, img_metas, trans.device)

        # undo post-transformation
        # B x N x D x H x W x 3
        points = frustum - post_trans.view(B, N, 1, 1, 1, 3)
        post_rots = torch.inverse(post_rots)
        points = self.matmul_1(post_rots, points, trans)

        # cam_to_ego, then ego_to_lidar translation
        intrins = torch.inverse(intrins)
        points = self.matmul_2(rots, intrins, trans, lidar2ego_trans, points)

        lidar2ego_rots = torch.inverse(lidar2ego_rots)
        points = self.matmul_3(lidar2ego_rots, points, trans)

        if "extra_rots" in kwargs:
            extra_rots = kwargs["extra_rots"]
            points = torch.matmul(
                extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(1, N, 1, 1, 1, 1, 1),
                points.unsqueeze(-1)).squeeze(-1)
        if "extra_trans" in kwargs:
            extra_trans = kwargs["extra_trans"]
            points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)

        return points

    @force_fp32()
    def get_geometry(
        self,
        fH,
        fW,
        lidar2img,
        img_metas,
    ):
        """Back-project the frustum by solving ``lidar2img @ p = frustum``.

        Returns the first three solved coordinates divided by the fourth,
        with the divisor clamped from below at 1e-5.
        """
        B, N, _, _ = lidar2img.shape
        frustum = self._ensure_frustum(fH, fW, img_metas, lidar2img.device)

        points = frustum.view(1, 1, self.D, fH, fW, 3) \
            .repeat(B, N, 1, 1, 1, 1)
        lidar2img = lidar2img.view(B, N, 1, 1, 1, 4, 4)
        # Homogeneous coordinates.
        points = torch.cat(
            (points, torch.ones_like(points[..., :1])), -1)
        # Solve instead of forming the explicit inverse of lidar2img.
        points = torch.linalg.solve(lidar2img.to(torch.float32),
                                    points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
        eps = 1e-5
        points = points[..., 0:3] / torch.maximum(
            points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps)

        return points

    def get_cam_feats(self, x, mlp_input=None):
        """Abstract hook; ``forward`` calls it as
        ``get_cam_feats(images, mlp_input)`` and expects ``(x, depth)``."""
        raise NotImplementedError

    def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda=None):
        """Abstract hook; ``forward`` calls it with the first four
        arguments only."""
        raise NotImplementedError

    @force_fp32()
    def bev_pool(self, geom_feats, x):
        """Scatter features ``x`` (B, N, D, H, W, C) into the voxel grid
        using the coordinates in ``geom_feats``."""
        B, N, D, H, W, C = x.shape
        Nprime = B * N * D * H * W

        # flatten x
        x = x.reshape(Nprime, C)

        # flatten indices: metric coordinates -> integer voxel indices
        geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
        geom_feats = geom_feats.view(Nprime, 3)
        batch_ix = torch.cat(
            [
                torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
                for ix in range(B)
            ]
        )
        geom_feats = torch.cat((geom_feats, batch_ix), 1)

        # filter out points that are outside box
        kept = (
            (geom_feats[:, 0] >= 0)
            & (geom_feats[:, 0] < self.nx[0])
            & (geom_feats[:, 1] >= 0)
            & (geom_feats[:, 1] < self.nx[1])
            & (geom_feats[:, 2] >= 0)
            & (geom_feats[:, 2] < self.nx[2])
        )
        x = x[kept]
        geom_feats = geom_feats[kept]

        x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])
        return x

    def stack_metas(self, metas, key, device, dtype):
        """Stack ``meta[key]`` over all metas into one tensor.

        ndarray values are converted with ``torch.from_numpy``; list values
        are stacked element-wise (ndarray elements converted first).
        """
        tensors = []
        for meta in metas:
            val = meta[key]
            if isinstance(val, np.ndarray):
                val = torch.from_numpy(val)
            elif isinstance(val, list):
                val = torch.stack(
                    [torch.from_numpy(v) if isinstance(v, np.ndarray) else v
                     for v in val],
                    dim=0)
            tensors.append(val)
        return torch.stack(tensors, dim=0).to(device=device, dtype=dtype)

    #@torch.compile(mode="max-autotune-no-cudagraphs")
    def extract_metas(self, images, img_metas):
        """Collect the calibration matrices from ``img_metas`` and split
        them into rotation / translation parts, on ``images``' device and
        dtype."""
        device = images.device
        dtype = images.dtype

        # 'lidar2img' is stacked like the other keys, but the result is not
        # part of the returned tuple.
        lidar2img = self.stack_metas(img_metas, 'lidar2img', device, dtype)
        camera2ego = self.stack_metas(img_metas, 'camera2ego', device, dtype)
        camera_intrinsics = self.stack_metas(img_metas, 'camera_intrinsics', device, dtype)
        img_aug_matrix = self.stack_metas(img_metas, 'img_aug_matrix', device, dtype)
        lidar2ego = self.stack_metas(img_metas, 'lidar2ego', device, dtype)

        rots = camera2ego[..., :3, :3]
        trans = camera2ego[..., :3, 3]
        intrins = camera_intrinsics[..., :3, :3]
        post_rots = img_aug_matrix[..., :3, :3]
        post_trans = img_aug_matrix[..., :3, 3]
        lidar2ego_rots = lidar2ego[..., :3, :3]
        lidar2ego_trans = lidar2ego[..., :3, 3]
        return (rots, trans, intrins, post_rots, post_trans,
                lidar2ego_rots, lidar2ego_trans, camera2ego, camera_intrinsics)

    @force_fp32()
    def forward(self, images, img_metas):
        """Compute geometry, camera features and pooled BEV features.

        Args:
            images (Tensor): Shape unpacks as (B, N, C, fH, fW).
            img_metas (list[dict]): Per-sample metadata with the
                calibration keys read by ``extract_metas``.

        Returns:
            tuple: ``(x, depth)`` where ``x`` is the ``bev_pool`` output and
            ``depth`` comes from ``get_cam_feats``.
        """
        B, N, C, fH, fW = images.shape
        (rots, trans, intrins, post_rots, post_trans,
         lidar2ego_rots, lidar2ego_trans,
         camera2ego, camera_intrinsics) = self.extract_metas(images, img_metas)

        geom = self.get_geometry_v1(
            fH,
            fW,
            rots,
            trans,
            intrins,
            post_rots,
            post_trans,
            lidar2ego_rots,
            lidar2ego_trans,
            img_metas
        )
        mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans)
        x, depth = self.get_cam_feats(images, mlp_input)
        x = self.bev_pool(geom, x)
        return x, depth
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
BaseTransformV2
(
BaseModule
):
def __init__(
    self,
    input_size,
    in_channels,
    out_channels,
    feat_down_sample,
    pc_range,
    voxel_size,
    dbound,
    sid=False,
):
    """Store the configuration and precompute grid info and the frustum.

    Args:
        input_size: Unpacked as (height, width) by ``create_frustum``.
        in_channels: Stored as ``self.in_channels``.
        out_channels: Stored as ``self.C``.
        feat_down_sample: Downsample factor passed to ``create_frustum``.
        pc_range: Indexed as (x_min, y_min, z_min, x_max, y_max, z_max).
        voxel_size: Indexed as (step_x, step_y, step_z).
        dbound: Depth config (lower, upper, step).
        sid (bool): Switches ``create_frustum`` to its log-spaced depth
            branch.
    """
    super(BaseTransformV2, self).__init__()
    self.in_channels = in_channels
    self.feat_down_sample = feat_down_sample

    # One (lower, upper, step) triple per axis, in x, y, z order.
    axis_bounds = [
        [pc_range[axis], pc_range[axis + 3], voxel_size[axis]]
        for axis in range(3)
    ]
    self.create_grid_infos(*axis_bounds)

    self.dbound = dbound
    # Must be set before create_frustum, which reads self.sid.
    self.sid = sid
    self.frustum = self.create_frustum(dbound, input_size, feat_down_sample)
    self.C = out_channels
    # Overrides the bin count that create_frustum stored in self.D.
    self.D = round((dbound[1] - dbound[0]) / dbound[2])
    self.fp16_enabled = False
def create_grid_infos(self, x, y, z, **kwargs):
    """Store the lower bound, interval and size of the voxel grid.

    Args:
        x (tuple(float)): Grid config along the x axis in the format
            (lower_bound, upper_bound, interval).
        y (tuple(float)): Same for the y axis.
        z (tuple(float)): Same for the z axis.
        **kwargs: Accepted and ignored.
    """
    axes = (x, y, z)
    lower_bounds = [axis[0] for axis in axes]
    intervals = [axis[2] for axis in axes]
    sizes = [(axis[1] - axis[0]) / axis[2] for axis in axes]

    self.grid_lower_bound = torch.Tensor(lower_bounds)
    self.grid_interval = torch.Tensor(intervals)
    self.grid_size = torch.Tensor(sizes)
# @force_fp32()
def create_frustum(self, depth_cfg, input_size, downsample):
    """Generate the frustum template for each image.

    Also stores the number of depth bins in ``self.D``.

    Args:
        depth_cfg (tuple(float)): Depth axis config in the format
            (lower_bound, upper_bound, interval).
        input_size (tuple(int)): Input image size as (height, width).
        downsample (int): Down sample factor from the input size to the
            feature size.

    Returns:
        Tensor: Shape (D, H_feat, W_feat, 3) holding (x, y, depth).
    """
    in_h, in_w = input_size
    feat_h = in_h // downsample
    feat_w = in_w // downsample

    depth_bins = torch.arange(*depth_cfg, dtype=torch.float)
    depth = depth_bins.view(-1, 1, 1).expand(-1, feat_h, feat_w)
    self.D = depth.shape[0]

    if self.sid:
        # Log-spaced bins from depth_cfg[0] up to depth_cfg[1] - 1.
        bin_index = torch.arange(self.D).float()
        cfg = torch.tensor(depth_cfg).float()
        log_span = torch.log((cfg[1] - 1) / cfg[0])
        sid_bins = torch.exp(
            torch.log(cfg[0]) + bin_index / (self.D - 1) * log_span)
        depth = sid_bins.view(-1, 1, 1).expand(-1, feat_h, feat_w)

    cols = torch.linspace(0, in_w - 1, feat_w, dtype=torch.float)
    cols = cols.view(1, 1, feat_w).expand(self.D, feat_h, feat_w)
    rows = torch.linspace(0, in_h - 1, feat_h, dtype=torch.float)
    rows = rows.view(1, feat_h, 1).expand(self.D, feat_h, feat_w)

    # D x H x W x 3
    return torch.stack((cols, rows, depth), -1)
def get_lidar_coor(self, fH, fW, rots, trans, intrins, post_rots, post_trans,
                   lidar2ego_rots, lidar2ego_trans, img_metas):
    """Map the cached frustum into the lidar frame.

    Batch and camera counts, device and dtype are taken from ``trans``.
    The body previously referenced an undefined name ``sensor2ego`` and
    raised NameError on every call.

    Args:
        fH, fW, img_metas: Unused here; kept for the call signature.
        rots (Tensor): Viewed as (B, N, 3, 3) camera rotations.
        trans (Tensor): (B, N, 3) camera translations.
        intrins (Tensor): (B, N, 3, 3) intrinsics (inverted here).
        post_rots (Tensor): (B, N, 3, 3) image-augmentation rotations
            (inverted here).
        post_trans (Tensor): (B, N, 3) image-augmentation translations.
        lidar2ego_rots (Tensor): (B, 3, 3), inverted here.
        lidar2ego_trans (Tensor): (B, 3).

    Returns:
        Tensor: Points shaped (B, N, D, H, W, 3).
    """
    B, N, _ = trans.shape

    # post-transformation
    # B x N x D x H x W x 3
    points = self.frustum.to(trans) - post_trans.view(B, N, 1, 1, 1, 3)
    points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3) \
        .matmul(points.unsqueeze(-1))

    # cam_to_ego: (u, v, d) -> (u * d, v * d, d) before un-projecting
    points = torch.cat(
        (points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)
    combine = rots.matmul(torch.inverse(intrins))
    points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
    points += trans.view(B, N, 1, 1, 1, 3)

    # ego_to_lidar
    points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
    points = (
        torch.inverse(lidar2ego_rots)
        .view(B, 1, 1, 1, 1, 3, 3)
        .matmul(points.unsqueeze(-1))
        .squeeze(-1)
    )
    return points
@force_fp32()
def get_geometry_v1(
    self,
    fH,
    fW,
    rots,
    trans,
    intrins,
    post_rots,
    post_trans,
    lidar2ego_rots,
    lidar2ego_trans,
    img_metas,
    **kwargs,
):
    """Map the precomputed frustum into the lidar frame.

    Applies, in order: the inverse image augmentation, depth scaling of
    (u, v), ``rots @ inverse(intrins)``, the camera translation, and the
    inverse lidar2ego transform. Optional ``extra_rots`` / ``extra_trans``
    in ``kwargs`` are applied last.

    Returns:
        Tensor: Points shaped (B, N, D, H, W, 3).
    """
    B, N, _ = trans.shape
    device = trans.device

    # undo post-transformation
    # B x N x D x H x W x 3
    pts = self.frustum.to(device) - post_trans.view(B, N, 1, 1, 1, 3)
    inv_post_rots = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3)
    pts = inv_post_rots.matmul(pts.unsqueeze(-1))

    # cam_to_ego
    depth = pts[:, :, :, :, :, 2:3]
    pts = torch.cat((pts[:, :, :, :, :, :2] * depth, depth), 5)
    cam_to_ego = rots.matmul(torch.inverse(intrins))
    pts = cam_to_ego.view(B, N, 1, 1, 1, 3, 3).matmul(pts).squeeze(-1)
    pts += trans.view(B, N, 1, 1, 1, 3)

    # ego_to_lidar
    pts -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
    inv_lidar_rots = torch.inverse(lidar2ego_rots).view(B, 1, 1, 1, 1, 3, 3)
    pts = inv_lidar_rots.matmul(pts.unsqueeze(-1)).squeeze(-1)

    if "extra_rots" in kwargs:
        extra_rots = kwargs["extra_rots"]
        tiled_rots = extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(
            1, N, 1, 1, 1, 1, 1)
        pts = tiled_rots.matmul(pts.unsqueeze(-1)).squeeze(-1)
    if "extra_trans" in kwargs:
        extra_trans = kwargs["extra_trans"]
        pts += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)

    return pts
@force_fp32()
def get_geometry(
    self,
    fH,
    fW,
    lidar2img,
    img_metas,
):
    """Back-project the frustum by solving ``lidar2img @ p = frustum``.

    Args:
        fH, fW (int): Feature-map height and width.
        lidar2img (Tensor): (B, N, 4, 4) projection matrices.
        img_metas: Unused here; kept for the call signature.

    Returns:
        Tensor: (B, N, D, fH, fW, 3) — the first three solved coordinates
        divided by the fourth, with the divisor clamped from below at 1e-5.
    """
    B, N, _, _ = lidar2img.shape
    device = lidar2img.device

    if self.frustum is None:
        # The old call passed (fH, fW, img_metas), which does not match
        # create_frustum(depth_cfg, input_size, downsample).
        # NOTE(review): the input size is reconstructed from the feature
        # size here; confirm this is acceptable when the real input size
        # is not a multiple of feat_down_sample.
        ds = self.feat_down_sample
        self.frustum = self.create_frustum(self.dbound, (fH * ds, fW * ds), ds)
    self.frustum = self.frustum.to(device)

    points = self.frustum.view(1, 1, self.D, fH, fW, 3) \
        .repeat(B, N, 1, 1, 1, 1)
    lidar2img = lidar2img.view(B, N, 1, 1, 1, 4, 4)
    # Homogeneous coordinates.
    points = torch.cat(
        (points, torch.ones_like(points[..., :1])), -1)
    # Solve instead of forming the explicit inverse of lidar2img.
    points = torch.linalg.solve(lidar2img.to(torch.float32),
                                points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
    eps = 1e-5
    points = points[..., 0:3] / torch.maximum(
        points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps)

    return points
def get_cam_feats(self, x):
    """Abstract hook: always raises NotImplementedError in this class."""
    raise NotImplementedError
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda):
    """Abstract hook: always raises NotImplementedError in this class."""
    raise NotImplementedError
def voxel_pooling_prepare_v2(self, coor):
    """Data preparation for voxel pooling.

    Args:
        coor (torch.tensor): Coordinate of points in the lidar space in
            shape (B, N, D, H, W, 3).

    Returns:
        tuple[torch.tensor]: ``(ranks_bev, ranks_depth, ranks_feat,
        interval_starts, interval_lengths)``, all int32 and contiguous.
        ``ranks_bev`` is the flat voxel index of each kept point (sorted),
        ``ranks_depth`` / ``ranks_feat`` are the matching flat indices into
        the depth and feature tensors, and the intervals delimit runs of
        equal ``ranks_bev``. Returns five ``None`` values when no point
        falls inside the grid.
    """
    B, N, D, H, W, _ = coor.shape
    num_points = B * N * D * H * W

    # record the index of selected points for acceleration purpose
    # (torch.arange replaces the deprecated, end-inclusive torch.range)
    ranks_depth = torch.arange(
        0, num_points, dtype=torch.int, device=coor.device)
    ranks_feat = torch.arange(
        0, num_points // D, dtype=torch.int, device=coor.device)
    ranks_feat = ranks_feat.reshape(B, N, 1, H, W)
    ranks_feat = ranks_feat.expand(B, N, D, H, W).flatten()

    # convert coordinate into the voxel space
    coor = ((coor - self.grid_lower_bound.to(coor)) /
            self.grid_interval.to(coor))
    coor = coor.long().view(num_points, 3)
    batch_idx = torch.arange(B, device=coor.device, dtype=coor.dtype) \
        .reshape(B, 1).expand(B, num_points // B).reshape(num_points, 1)
    coor = torch.cat((coor, batch_idx), 1)

    # filter out points that are outside box
    kept = (coor[:, 0] >= 0) & (coor[:, 0] < self.grid_size[0]) & \
           (coor[:, 1] >= 0) & (coor[:, 1] < self.grid_size[1]) & \
           (coor[:, 2] >= 0) & (coor[:, 2] < self.grid_size[2])
    # `len(kept)` is the mask length, not the number of kept points, so
    # the old check could never detect an empty selection.
    if not kept.any():
        return None, None, None, None, None
    coor, ranks_depth, ranks_feat = \
        coor[kept], ranks_depth[kept], ranks_feat[kept]

    # get tensors from the same voxel next to each other
    ranks_bev = coor[:, 3] * (
        self.grid_size[2] * self.grid_size[1] * self.grid_size[0])
    ranks_bev += coor[:, 2] * (self.grid_size[1] * self.grid_size[0])
    ranks_bev += coor[:, 1] * self.grid_size[0] + coor[:, 0]
    order = ranks_bev.argsort()
    ranks_bev, ranks_depth, ranks_feat = \
        ranks_bev[order], ranks_depth[order], ranks_feat[order]

    # Mark the first element of every run of equal voxel ranks.
    kept = torch.ones(
        ranks_bev.shape[0], device=ranks_bev.device, dtype=torch.bool)
    kept[1:] = ranks_bev[1:] != ranks_bev[:-1]
    interval_starts = torch.where(kept)[0].int()
    if len(interval_starts) == 0:
        return None, None, None, None, None
    interval_lengths = torch.zeros_like(interval_starts)
    interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
    interval_lengths[-1] = ranks_bev.shape[0] - interval_starts[-1]
    return ranks_bev.int().contiguous(), ranks_depth.int().contiguous(
    ), ranks_feat.int().contiguous(), interval_starts.int().contiguous(
    ), interval_lengths.int().contiguous()
@force_fp32()
def voxel_pooling_v2(self, coor, depth, feat):
    """Pool ``depth`` and ``feat`` into the BEV grid with ``bev_pool_v2``.

    Args:
        coor (Tensor): Point coordinates passed to
            ``voxel_pooling_prepare_v2``.
        depth (Tensor): Depth tensor; its first axis is used as the batch
            size of the output.
        feat (Tensor): Features, permuted with (0, 1, 3, 4, 2) before
            pooling.

    Returns:
        Tensor: Pooled features with axis 2 unbound and concatenated along
        axis 1, or an all-zero tensor built the same way when no point lies
        inside the grid.
    """
    prepared = self.voxel_pooling_prepare_v2(coor)
    ranks_bev, ranks_depth, ranks_feat, interval_starts, interval_lengths = prepared

    grid_x = int(self.grid_size[0])
    grid_y = int(self.grid_size[1])
    grid_z = int(self.grid_size[2])

    if ranks_feat is None:
        print('warning ---> no points within the predefined '
              'bev receptive field')
        empty = torch.zeros(size=[
            feat.shape[0], feat.shape[2], grid_z, grid_x, grid_y
        ]).to(feat)
        return torch.cat(empty.unbind(dim=2), 1)

    feat = feat.permute(0, 1, 3, 4, 2)
    # (B, Z, Y, X, C)
    bev_feat_shape = (depth.shape[0], grid_z, grid_y, grid_x, feat.shape[-1])
    pooled = bev_pool_v2(depth, feat, ranks_depth, ranks_feat, ranks_bev,
                         bev_feat_shape, interval_starts,
                         interval_lengths)
    # collapse Z
    return torch.cat(pooled.unbind(dim=2), 1)
@force_fp32()
def bev_pool(self, geom_feats, x):
    """Scatter features into the voxel grid and collapse the Z axis.

    The previous body read ``self.bx``, ``self.dx`` and ``self.nx``, which
    the visible ``__init__`` of this class never sets; it stores
    ``grid_lower_bound``, ``grid_interval`` and ``grid_size`` instead.
    ``bx - dx / 2`` equals the lower bound of the grid, so those attributes
    are used here.
    NOTE(review): if a subclass defines bx/dx/nx on purpose, confirm this
    substitution matches its intent.

    Args:
        geom_feats (Tensor): Point coordinates with 3 entries on the last
            axis, one per element of ``x``.
        x (Tensor): Features of shape (B, N, D, H, W, C).

    Returns:
        Tensor: ``bev_pool`` output with axis 2 unbound and concatenated
        along axis 1.
    """
    B, N, D, H, W, C = x.shape
    Nprime = B * N * D * H * W

    # flatten x
    x = x.reshape(Nprime, C)

    # flatten indices: metric coordinates -> integer voxel indices
    lower = self.grid_lower_bound.to(geom_feats)
    step = self.grid_interval.to(geom_feats)
    geom_feats = ((geom_feats - lower) / step).long()
    geom_feats = geom_feats.view(Nprime, 3)
    batch_ix = torch.cat(
        [
            torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
            for ix in range(B)
        ]
    )
    geom_feats = torch.cat((geom_feats, batch_ix), 1)

    # filter out points that are outside box
    nx, ny, nz = (int(size) for size in self.grid_size)
    kept = (
        (geom_feats[:, 0] >= 0)
        & (geom_feats[:, 0] < nx)
        & (geom_feats[:, 1] >= 0)
        & (geom_feats[:, 1] < ny)
        & (geom_feats[:, 2] >= 0)
        & (geom_feats[:, 2] < nz)
    )
    x = x[kept]
    geom_feats = geom_feats[kept]

    x = bev_pool(x, geom_feats, B, nz, nx, ny)

    # collapse Z
    final = torch.cat(x.unbind(dim=2), 1)

    return final
@force_fp32()
def forward(self, images, img_metas):
    """Lift multi-camera image features into a BEV feature map.

    Args:
        images: Feature tensor with five axes; the last two are read as the
            feature-map height and width.
        img_metas: Iterable of per-sample dicts providing the
            ``'camera2ego'``, ``'camera_intrinsics'``, ``'img_aug_matrix'``
            and ``'lidar2ego'`` matrices.

    Returns:
        tuple: ``(bev_feat, depth)`` where ``bev_feat`` is the result of
        ``self.voxel_pooling_v2`` and ``depth`` is the second output of
        ``self.get_cam_feats``.
    """
    _, _, _, fH, fW = images.shape

    def _stack_meta(key):
        # Go through NumPy first so the per-sample entries become one dense
        # array, then copy it to the dtype/device of `images`.
        return images.new_tensor(
            np.asarray([img_meta[key] for img_meta in img_metas]))

    # NOTE(review): the original comments annotate each of these as
    # (B, N, 4, 4); lidar2ego may be (B, 4, 4) -- confirm with the dataset.
    camera2ego = _stack_meta('camera2ego')
    camera_intrinsics = _stack_meta('camera_intrinsics')
    img_aug_matrix = _stack_meta('img_aug_matrix')
    lidar2ego = _stack_meta('lidar2ego')

    # Split every homogeneous matrix into its 3x3 block and translation.
    rots = camera2ego[..., :3, :3]
    trans = camera2ego[..., :3, 3]
    intrins = camera_intrinsics[..., :3, :3]
    post_rots = img_aug_matrix[..., :3, :3]
    post_trans = img_aug_matrix[..., :3, 3]
    lidar2ego_rots = lidar2ego[..., :3, :3]
    lidar2ego_trans = lidar2ego[..., :3, 3]

    coor = self.get_geometry_v1(fH, fW, rots, trans, intrins, post_rots,
                                post_trans, lidar2ego_rots,
                                lidar2ego_trans, img_metas)

    # Camera parameters that condition the depth network.
    mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics,
                                   post_rots, post_trans)
    tran_feat, depth = self.get_cam_feats(images, mlp_input)

    bev_feat = self.voxel_pooling_v2(coor, depth, tran_feat)
    return bev_feat, depth
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): Size of the input feature vector.
        hidden_features (int, optional): Width of the hidden layer; falls
            back to ``in_features`` when falsy.
        out_features (int, optional): Size of the output vector; falls back
            to ``in_features`` when falsy.
        act_layer: Callable returning the activation module.
        drop (float): Dropout probability used after both linear layers.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Attribute names and creation order are kept so checkpoints and
        # seeded initialisation stay identical.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply the five stages in order and return the result."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
class SELayer(nn.Module):
    """Gate a feature map with a signal computed from a second input.

    Args:
        channels (int): Channel count of both convolutions (input == output).
        act_layer: Callable returning the activation between the two convs.
        gate_layer: Callable returning the gating non-linearity.
    """

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        # Two 1x1 convolutions with an activation in between.
        self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        """Return ``x`` multiplied by the gate derived from ``x_se``."""
        gate_logits = self.conv_expand(self.act1(self.conv_reduce(x_se)))
        return x * self.gate(gate_logits)
class DepthNet(nn.Module):
    """Camera-aware head predicting depth logits and, optionally, context.

    Args:
        in_channels (int): Channels of the incoming image feature map.
        mid_channels (int): Width used inside the head.
        context_channels (int): Channels of the context output; ``0``
            switches the head to depth-only mode.
        depth_channels (int): Number of depth logits per pixel.
        use_dcn (bool): Append a ``'DCN'`` conv layer to the depth branch.
        use_aspp (bool): Append an ``ASPP`` block to the depth branch.
        with_cp (bool): Run the depth branch through ``checkpoint``.
        aspp_mid_channels (int): Width of the ASPP branches; negative means
            "use ``mid_channels``".
        only_depth (bool): Force depth-only mode.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 context_channels,
                 depth_channels,
                 use_dcn=True,
                 use_aspp=True,
                 with_cp=False,
                 aspp_mid_channels=-1,
                 only_depth=False):
        super().__init__()
        # NOTE: sub-modules are created in the same order as before so that
        # state-dict ordering and seeded initialisation do not change.
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.only_depth = only_depth or context_channels == 0
        if not self.only_depth:
            self.context_conv = nn.Conv2d(
                mid_channels, context_channels,
                kernel_size=1, stride=1, padding=0)
            self.context_mlp = Mlp(22, mid_channels, mid_channels)
            # Camera-aware gating of the context branch.
            self.context_se = SELayer(mid_channels)

        # 22 is the width of the camera-parameter vector fed to `forward`.
        self.bn = nn.BatchNorm1d(22)
        self.depth_mlp = Mlp(22, mid_channels, mid_channels)
        # Camera-aware gating of the depth branch.
        self.depth_se = SELayer(mid_channels)

        depth_layers = [BasicBlock(mid_channels, mid_channels)
                        for _ in range(3)]
        if use_aspp:
            aspp_width = (mid_channels if aspp_mid_channels < 0
                          else aspp_mid_channels)
            depth_layers.append(ASPP(mid_channels, aspp_width))
        if use_dcn:
            dcn_cfg = dict(
                type='DCN',
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1,
                groups=4,
                im2col_step=128,
            )
            depth_layers.append(build_conv_layer(cfg=dcn_cfg))
        depth_layers.append(
            nn.Conv2d(mid_channels, depth_channels,
                      kernel_size=1, stride=1, padding=0))
        self.depth_conv = nn.Sequential(*depth_layers)
        self.with_cp = with_cp

    def forward(self, x, mlp_input):
        """Compute depth logits (and context features) for ``x``.

        Args:
            x: Image feature map fed to ``reduce_conv``.
            mlp_input: Camera parameters; flattened to
                ``(-1, mlp_input.shape[-1])`` before batch normalisation.

        Returns:
            The depth logits, or in context mode the concatenation
            ``[depth, context]`` along the channel axis.
        """
        cam_params = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
        feats = self.reduce_conv(x)

        context = None
        if not self.only_depth:
            # Trailing singleton axes let the gate broadcast over H and W.
            context_gate = self.context_mlp(cam_params)[..., None, None]
            context = self.context_conv(self.context_se(feats, context_gate))

        depth_gate = self.depth_mlp(cam_params)[..., None, None]
        depth = self.depth_se(feats, depth_gate)
        if self.with_cp:
            depth = checkpoint(self.depth_conv, depth)
        else:
            depth = self.depth_conv(depth)

        if self.only_depth:
            return depth
        return torch.cat([depth, context], dim=1)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoderDepth(BEVFormerEncoder):
    """BEVFormer encoder with an auxiliary per-camera depth head.

    The BEV embedding is produced by the parent encoder; in addition a
    depth-only ``DepthNet`` predicts a depth distribution from the first
    entry of ``mlvl_feats`` so that a depth loss can be applied.

    Args:
        in_channels (int): Channels of the image features fed to the depth
            head (also used as its hidden width).
        out_channels (int): Accepted for config compatibility; not used.
        feat_down_sample (int): Stride between the ground-truth depth map
            and the feature map, used when down-sampling depth labels.
        loss_depth_weight (float): Multiplier applied to the depth loss.
        depthnet_cfg (dict, optional): Extra keyword arguments for
            ``DepthNet``.
        grid_config (dict): Must contain ``'depth'`` as
            ``(min, max, step)``; required.
    """

    def __init__(self,
                 *args,
                 in_channels=256,
                 out_channels=256,
                 feat_down_sample=32,
                 loss_depth_weight=3.0,
                 depthnet_cfg=None,
                 grid_config=None,
                 **kwargs):
        super(BEVFormerEncoderDepth, self).__init__(*args, **kwargs)
        if grid_config is None:
            # Same exception type as the former implicit failure
            # ("'NoneType' object is not subscriptable"), clearer message.
            raise TypeError(
                "BEVFormerEncoderDepth requires `grid_config` with a "
                "'depth' entry of the form (min, max, step)")
        # Avoid sharing one mutable default dict between instances.
        if depthnet_cfg is None:
            depthnet_cfg = dict()

        self.fp16_enabled = False
        self.loss_depth_weight = loss_depth_weight
        self.feat_down_sample = feat_down_sample
        self.grid_config = grid_config
        depth_min, depth_max, depth_step = (grid_config['depth'][0],
                                            grid_config['depth'][1],
                                            grid_config['depth'][2])
        # Number of discrete depth bins.
        self.D = int((depth_max - depth_min) / depth_step)
        # context_channels=0 puts DepthNet into depth-only mode.
        self.depth_net = DepthNet(in_channels, in_channels, 0, self.D,
                                  **depthnet_cfg)

    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                mlvl_feats=None,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
        """Run the parent encoder and predict per-camera depth.

        Args:
            bev_query, key, value: Forwarded unchanged to the parent
                encoder.
            mlvl_feats: Sequence of image feature tensors; only the first
                entry is used by the depth head.
            bev_h, bev_w, bev_pos, spatial_shapes, level_start_index,
                prev_bev, shift: Forwarded unchanged to the parent encoder.
            valid_ratios: Accepted but not forwarded.
            **kwargs: Forwarded to the parent encoder; must contain
                ``'img_metas'``, a list of dicts with the keys
                ``'camera2ego'``, ``'camera_intrinsics'`` and
                ``'img_aug_matrix'``.

        Returns:
            dict: ``bev`` (the ``'bev'`` entry of the parent's output) and
            ``depth`` (output of ``get_cam_feats``).
        """
        bev_embed = super().forward(
            bev_query,
            key,
            value,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            **kwargs)

        images = mlvl_feats[0]
        img_metas = kwargs['img_metas']

        def _stack_meta(name):
            # NumPy first so the per-sample entries become one dense array,
            # then copy to the dtype/device of the image features.
            return images.new_tensor(
                np.asarray([img_meta[name] for img_meta in img_metas]))

        # Only the matrices that feed the depth head are gathered.
        camera2ego = _stack_meta('camera2ego')
        camera_intrinsics = _stack_meta('camera_intrinsics')
        img_aug_matrix = _stack_meta('img_aug_matrix')

        post_rots = img_aug_matrix[..., :3, :3]
        post_trans = img_aug_matrix[..., :3, 3]

        mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics,
                                       post_rots, post_trans)
        depth = self.get_cam_feats(images, mlp_input)

        return dict(bev=bev_embed['bev'], depth=depth)

    @force_fp32()
    def get_cam_feats(self, x, mlp_input):
        """Predict a per-pixel depth distribution for every camera.

        Args:
            x: Image features with shape ``(B, N, C, fH, fW)``.
            mlp_input: Camera-parameter vector passed to ``depth_net``.

        Returns:
            Softmax depth probabilities viewed as ``(B, N, D, fH, fW)``.
        """
        B, N, C, fH, fW = x.shape
        x = x.view(B * N, C, fH, fW)
        x = self.depth_net(x, mlp_input)
        depth = x[:, :self.D].softmax(dim=1)
        return depth.view(B, N, self.D, fH, fW)

    def get_downsampled_gt_depth(self, gt_depths):
        """Down-sample ground-truth depth and convert it to one-hot bins.

        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]
        """
        B, N, H, W = gt_depths.shape
        ds = self.feat_down_sample

        # Group pixels into ds x ds patches: one row per patch.
        gt_depths = gt_depths.view(B * N, H // ds, ds, W // ds, ds, 1)
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(-1, ds * ds)

        # Depth is sparse and most entries are 0, so replace zeros with
        # 1e5; the minimum over a patch is then its smallest non-zero
        # depth.
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    1e5 * torch.ones_like(gt_depths),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, H // ds, W // ds)

        # Metric depth -> bin index, shifted by one so bin 0 means "none".
        depth_min = self.grid_config['depth'][0]
        depth_step = self.grid_config['depth'][2]
        gt_depths = (gt_depths - (depth_min - depth_step)) / depth_step
        # Out-of-range values (including the 1e5 fillers) go to bin 0.
        gt_depths = torch.where(
            (gt_depths < self.D + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))
        # One-hot over D + 1 classes, then drop the "none" class.
        gt_depths = F.one_hot(
            gt_depths.long(),
            num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:]
        return gt_depths.float()

    @force_fp32()
    def get_depth_loss(self, depth_labels, depth_preds):
        """Weighted binary cross-entropy between predicted and GT depth.

        Args:
            depth_labels: Ground-truth depth maps, ``[B, N, H, W]``.
            depth_preds: Predicted depth probabilities
                ``(B, N, D, fH, fW)``, or ``None``.

        Returns:
            ``0`` when ``depth_preds`` is ``None``; otherwise the summed
            BCE divided by ``max(1.0, number of selected entries)`` and
            scaled by ``self.loss_depth_weight``.
        """
        if depth_preds is None:
            return 0

        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 1, 3, 4,
                                          2).contiguous().view(-1, self.D)
        # Element-wise mask: only the entries whose one-hot label is set
        # contribute, i.e. pixels that have a valid depth.
        fg_mask = depth_labels > 0.0
        depth_labels = depth_labels[fg_mask]
        depth_preds = depth_preds[fg_mask]
        with autocast(enabled=False):
            depth_loss = F.binary_cross_entropy(
                depth_preds,
                depth_labels,
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum())
        return self.loss_depth_weight * depth_loss

    def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
        """Assemble the camera-parameter vector for the depth head.

        Args:
            sensor2ego: ``(B, N, R, K)`` sensor-to-ego matrices; the first
                three rows are flattened and appended.
            intrin: Camera intrinsics, indexed as ``[:, :, row, col]``.
            post_rot: Image-augmentation rotation, indexed like ``intrin``.
            post_tran: Image-augmentation translation, ``[:, :, i]``.

        Returns:
            Tensor ``(B, N, 10 + 3 * K)``: ten intrinsic/augmentation
            scalars followed by the flattened extrinsics.
        """
        B, N, _, _ = sensor2ego.shape
        scalars = torch.stack([
            intrin[:, :, 0, 0],
            intrin[:, :, 1, 1],
            intrin[:, :, 0, 2],
            intrin[:, :, 1, 2],
            post_rot[:, :, 0, 0],
            post_rot[:, :, 0, 1],
            post_tran[:, :, 0],
            post_rot[:, :, 1, 0],
            post_rot[:, :, 1, 1],
            post_tran[:, :, 1],
        ], dim=-1)
        extrinsics = sensor2ego[:, :, :3, :].reshape(B, N, -1)
        return torch.cat([scalars, extrinsics], dim=-1)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class LSSTransform(BaseTransform):
    """Lift-splat view transform with a camera-aware depth network.

    Args:
        in_channels, out_channels, feat_down_sample, pc_range, voxel_size,
            dbound: Forwarded to ``BaseTransform``.
        downsample (int): ``1`` keeps the BEV resolution; ``2`` adds a
            stride-2 convolutional stack. Other values > 1 are rejected.
        loss_depth_weight (float): Multiplier applied to the depth loss.
        depthnet_cfg (dict, optional): Extra keyword arguments for
            ``DepthNet``.
        grid_config (dict, optional): Must contain ``'depth'`` as
            ``(min, max, step)`` when the depth loss is used.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        feat_down_sample,
        pc_range,
        voxel_size,
        dbound,
        downsample=1,
        loss_depth_weight=3.0,
        depthnet_cfg=None,
        grid_config=None,
    ):
        super(LSSTransform, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            feat_down_sample=feat_down_sample,
            pc_range=pc_range,
            voxel_size=voxel_size,
            dbound=dbound,
        )
        # Avoid sharing one mutable default dict between instances.
        if depthnet_cfg is None:
            depthnet_cfg = dict()

        self.loss_depth_weight = loss_depth_weight
        self.grid_config = grid_config
        # self.C / self.D are expected to be set by the base class.
        self.depth_net = DepthNet(in_channels, in_channels, self.C, self.D,
                                  **depthnet_cfg)
        if downsample > 1:
            assert downsample == 2, downsample
            # Three conv-BN-ReLU stages; only the middle one is strided.
            layers = []
            for stride in (1, downsample, 1):
                layers.extend([
                    nn.Conv2d(out_channels, out_channels, 3,
                              stride=stride, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True),
                ])
            self.downsample = nn.Sequential(*layers)
        else:
            self.downsample = nn.Identity()

    @force_fp32()
    def get_cam_feats(self, x, mlp_input):
        """Predict depth and lift image features along the depth axis.

        Args:
            x: Image features with shape ``(B, N, C, fH, fW)``.
            mlp_input: Camera-parameter vector passed to ``depth_net``.

        Returns:
            tuple: lifted features ``(B, N, D, fH, fW, self.C)`` and depth
            probabilities ``(B, N, D, fH, fW)``.
        """
        B, N, C, fH, fW = x.shape
        x = x.view(B * N, C, fH, fW)
        x = self.depth_net(x, mlp_input)

        # First D channels are depth logits, the next self.C are context.
        depth = x[:, :self.D].softmax(dim=1)
        context = x[:, self.D:(self.D + self.C)]
        # Outer product over the (context, depth) axes.
        x = depth.unsqueeze(1) * context.unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        depth = depth.view(B, N, self.D, fH, fW)
        return x, depth

    def down_sample(self, x):
        """Fold axis 1 of a 5-axis tensor into channels and down-sample.

        ``x`` is permuted to put its last axis second, the former axis 1
        is unbound and concatenated into the channel axis, the two
        spatial axes are swapped, and ``self.downsample`` is applied.
        """
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        x = torch.cat(x.unbind(dim=2), 1)
        x = x.permute(0, 1, 3, 2).contiguous()
        return self.downsample(x)

    def forward(self, images, img_metas):
        """Run the base transform and return BEV features plus depth.

        Returns:
            dict: ``bev`` (down-sampled BEV features) and ``depth`` (as
            returned by the base class).
        """
        x, depth = super().forward(images, img_metas)
        x = self.down_sample(x)
        return dict(bev=x, depth=depth)

    def get_downsampled_gt_depth(self, gt_depths):
        """Down-sample ground-truth depth and convert it to one-hot bins.

        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]
        """
        B, N, H, W = gt_depths.shape
        ds = self.feat_down_sample

        # Group pixels into ds x ds patches: one row per patch.
        gt_depths = gt_depths.view(B * N, H // ds, ds, W // ds, ds, 1)
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(-1, ds * ds)

        # Depth is sparse and most entries are 0, so replace zeros with
        # 1e5; the minimum over a patch is then its smallest non-zero
        # depth.
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    1e5 * torch.ones_like(gt_depths),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, H // ds, W // ds)

        # Metric depth -> bin index, shifted by one so bin 0 means "none".
        depth_min = self.grid_config['depth'][0]
        depth_step = self.grid_config['depth'][2]
        gt_depths = (gt_depths - (depth_min - depth_step)) / depth_step
        # Out-of-range values (including the 1e5 fillers) go to bin 0.
        gt_depths = torch.where(
            (gt_depths < self.D + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))
        # One-hot over D + 1 classes, then drop the "none" class.
        gt_depths = F.one_hot(
            gt_depths.long(),
            num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:]
        return gt_depths.float()

    @force_fp32()
    def get_depth_loss(self, depth_labels, depth_preds):
        """Weighted binary cross-entropy between predicted and GT depth.

        Args:
            depth_labels: Ground-truth depth maps, ``[B, N, H, W]``.
            depth_preds: Predicted depth probabilities
                ``(B, N, D, fH, fW)``, or ``None``.

        Returns:
            ``0`` when ``depth_preds`` is ``None``; otherwise the summed
            BCE divided by ``max(1.0, number of selected entries)`` and
            scaled by ``self.loss_depth_weight``.
        """
        if depth_preds is None:
            return 0

        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 1, 3, 4,
                                          2).contiguous().view(-1, self.D)
        # Element-wise mask: only the entries whose one-hot label is set
        # contribute, i.e. pixels that have a valid depth.
        fg_mask = depth_labels > 0.0
        depth_labels = depth_labels[fg_mask]
        depth_preds = depth_preds[fg_mask]
        with autocast(enabled=False):
            depth_loss = F.binary_cross_entropy(
                depth_preds,
                depth_labels,
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum())
        return self.loss_depth_weight * depth_loss

    def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
        """Assemble the camera-parameter vector for the depth network.

        Args:
            sensor2ego: ``(B, N, R, K)`` sensor-to-ego matrices; the first
                three rows are flattened and appended.
            intrin: Camera intrinsics, indexed as ``[:, :, row, col]``.
            post_rot: Image-augmentation rotation, indexed like ``intrin``.
            post_tran: Image-augmentation translation, ``[:, :, i]``.

        Returns:
            Tensor ``(B, N, 10 + 3 * K)``: ten intrinsic/augmentation
            scalars followed by the flattened extrinsics.
        """
        B, N, _, _ = sensor2ego.shape
        scalars = torch.stack([
            intrin[:, :, 0, 0],
            intrin[:, :, 1, 1],
            intrin[:, :, 0, 2],
            intrin[:, :, 1, 2],
            post_rot[:, :, 0, 0],
            post_rot[:, :, 0, 1],
            post_tran[:, :, 0],
            post_rot[:, :, 1, 0],
            post_rot[:, :, 1, 1],
            post_tran[:, :, 1],
        ], dim=-1)
        extrinsics = sensor2ego[:, :, :3, :].reshape(B, N, -1)
        return torch.cat([scalars, extrinsics], dim=-1)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class LSSTransformV2(BaseTransformV2):
    """Lift-splat view transform (v2 pooling) with a depth network.

    Unlike ``LSSTransform`` the image features are not multiplied with the
    depth distribution here; both are returned separately by
    ``get_cam_feats``.

    Args:
        input_size, in_channels, out_channels, feat_down_sample, pc_range,
            voxel_size, dbound, sid: Forwarded to ``BaseTransformV2``.
        downsample (int): ``1`` keeps the BEV resolution; ``2`` adds a
            stride-2 convolutional stack. Other values > 1 are rejected.
        loss_depth_weight (float): Multiplier applied to the depth loss.
        depthnet_cfg (dict, optional): Extra keyword arguments for
            ``DepthNet``.
        grid_config (dict, optional): Must contain ``'depth'`` as
            ``(min, max, step)`` when the depth loss is used.
    """

    def __init__(
        self,
        input_size,
        in_channels,
        out_channels,
        feat_down_sample,
        pc_range,
        voxel_size,
        dbound,
        downsample=1,
        loss_depth_weight=3.0,
        depthnet_cfg=None,
        grid_config=None,
        sid=False,
    ):
        super(LSSTransformV2, self).__init__(
            input_size=input_size,
            in_channels=in_channels,
            out_channels=out_channels,
            feat_down_sample=feat_down_sample,
            pc_range=pc_range,
            voxel_size=voxel_size,
            dbound=dbound,
            sid=sid,
        )
        # Avoid sharing one mutable default dict between instances.
        if depthnet_cfg is None:
            depthnet_cfg = dict()

        self.loss_depth_weight = loss_depth_weight
        self.grid_config = grid_config
        # self.in_channels / self.C / self.D are expected to be set by the
        # base class.
        self.depth_net = DepthNet(self.in_channels, self.in_channels,
                                  self.C, self.D, **depthnet_cfg)
        if downsample > 1:
            assert downsample == 2, downsample
            # Three conv-BN-ReLU stages; only the middle one is strided.
            layers = []
            for stride in (1, downsample, 1):
                layers.extend([
                    nn.Conv2d(out_channels, out_channels, 3,
                              stride=stride, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True),
                ])
            self.downsample = nn.Sequential(*layers)
        else:
            self.downsample = nn.Identity()

    @force_fp32()
    def get_cam_feats(self, x, mlp_input):
        """Predict depth and context features for every camera.

        Args:
            x: Image features with shape ``(B, N, C, fH, fW)``.
            mlp_input: Camera-parameter vector passed to ``depth_net``.

        Returns:
            tuple: context features ``(B, N, self.C, fH, fW)`` and depth
            probabilities ``(B, N, D, fH, fW)``.
        """
        B, N, C, fH, fW = x.shape
        x = x.view(B * N, C, fH, fW)
        x = self.depth_net(x, mlp_input)

        # First D channels are depth logits, the next self.C are context.
        depth = x[:, :self.D].softmax(dim=1)
        tran_feat = x[:, self.D:(self.D + self.C)]

        tran_feat = tran_feat.view(B, N, self.C, fH, fW)
        depth = depth.view(B, N, self.D, fH, fW)
        return tran_feat, depth

    def forward(self, images, img_metas):
        """Run the base transform and return BEV features plus depth.

        Returns:
            dict: ``bev`` (output of ``self.downsample``) and ``depth`` (as
            returned by the base class).
        """
        x, depth = super().forward(images, img_metas)
        x = self.downsample(x)
        return dict(bev=x, depth=depth)

    def get_downsampled_gt_depth(self, gt_depths):
        """Down-sample ground-truth depth and convert it to one-hot bins.

        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]
        """
        B, N, H, W = gt_depths.shape
        ds = self.feat_down_sample

        # Group pixels into ds x ds patches: one row per patch.
        gt_depths = gt_depths.view(B * N, H // ds, ds, W // ds, ds, 1)
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(-1, ds * ds)

        # Depth is sparse and most entries are 0, so replace zeros with
        # 1e5; the minimum over a patch is then its smallest non-zero
        # depth.
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    1e5 * torch.ones_like(gt_depths),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, H // ds, W // ds)

        # Metric depth -> bin index, shifted by one so bin 0 means "none".
        depth_min = self.grid_config['depth'][0]
        depth_step = self.grid_config['depth'][2]
        gt_depths = (gt_depths - (depth_min - depth_step)) / depth_step
        # Out-of-range values (including the 1e5 fillers) go to bin 0.
        gt_depths = torch.where(
            (gt_depths < self.D + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))
        # One-hot over D + 1 classes, then drop the "none" class.
        gt_depths = F.one_hot(
            gt_depths.long(),
            num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:]
        return gt_depths.float()

    @force_fp32()
    def get_depth_loss(self, depth_labels, depth_preds):
        """Weighted binary cross-entropy between predicted and GT depth.

        Args:
            depth_labels: Ground-truth depth maps, ``[B, N, H, W]``.
            depth_preds: Predicted depth probabilities
                ``(B, N, D, fH, fW)``, or ``None``.

        Returns:
            ``0`` when ``depth_preds`` is ``None``; otherwise the summed
            BCE divided by ``max(1.0, number of selected entries)`` and
            scaled by ``self.loss_depth_weight``.
        """
        if depth_preds is None:
            return 0

        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 1, 3, 4,
                                          2).contiguous().view(-1, self.D)
        # Element-wise mask: only the entries whose one-hot label is set
        # contribute, i.e. pixels that have a valid depth.
        fg_mask = depth_labels > 0.0
        depth_labels = depth_labels[fg_mask]
        depth_preds = depth_preds[fg_mask]
        with autocast(enabled=False):
            depth_loss = F.binary_cross_entropy(
                depth_preds,
                depth_labels,
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum())
        return self.loss_depth_weight * depth_loss

    def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
        """Assemble the camera-parameter vector for the depth network.

        Args:
            sensor2ego: ``(B, N, R, K)`` sensor-to-ego matrices; the first
                three rows are flattened and appended.
            intrin: Camera intrinsics, indexed as ``[:, :, row, col]``.
            post_rot: Image-augmentation rotation, indexed like ``intrin``.
            post_tran: Image-augmentation translation, ``[:, :, i]``.

        Returns:
            Tensor ``(B, N, 10 + 3 * K)``: ten intrinsic/augmentation
            scalars followed by the flattened extrinsics.
        """
        B, N, _, _ = sensor2ego.shape
        scalars = torch.stack([
            intrin[:, :, 0, 0],
            intrin[:, :, 1, 1],
            intrin[:, :, 0, 2],
            intrin[:, :, 1, 2],
            post_rot[:, :, 0, 0],
            post_rot[:, :, 0, 1],
            post_tran[:, :, 0],
            post_rot[:, :, 1, 0],
            post_rot[:, :, 1, 1],
            post_tran[:, :, 1],
        ], dim=-1)
        extrinsics = sensor2ego[:, :, :3, :].reshape(B, N, -1)
        return torch.cat([scalars, extrinsics], dim=-1)
class
_ASPPModule
(
nn
.
Module
):
def
__init__
(
self
,
inplanes
,
planes
,
kernel_size
,
padding
,
dilation
,
BatchNorm
):
super
(
_ASPPModule
,
self
).
__init__
()
self
.
atrous_conv
=
nn
.
Conv2d
(
inplanes
,
planes
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
False
)
self
.
bn
=
BatchNorm
(
planes
)
self
.
relu
=
nn
.
ReLU
()
self
.
_init_weight
()
def
forward
(
self
,
x
):
x
=
self
.
atrous_conv
(
x
)
x
=
self
.
bn
(
x
)
return
self
.
relu
(
x
)
def
_init_weight
(
self
):
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
torch
.
nn
.
init
.
kaiming_normal_
(
m
.
weight
)
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
weight
.
data
.
fill_
(
1
)
m
.
bias
.
data
.
zero_
()
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with a global-average branch.

    Four ``_ASPPModule`` branches (dilations 1, 6, 12, 18) and one
    image-level pooling branch are concatenated and projected back to
    ``inplanes`` channels.

    Args:
        inplanes (int): Input and output channel count.
        mid_channels (int): Channels of every branch.
        BatchNorm: Callable building a normalisation layer from a channel
            count.
    """

    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        # (kernel_size, padding, dilation) of aspp1 .. aspp4; registered in
        # this order so the state-dict layout is unchanged.
        branch_specs = ((1, 0, 1), (3, 6, 6), (3, 12, 12), (3, 18, 18))
        for idx, (kernel, pad, rate) in enumerate(branch_specs, start=1):
            setattr(
                self, 'aspp{}'.format(idx),
                _ASPPModule(inplanes, mid_channels, kernel,
                            padding=pad, dilation=rate,
                            BatchNorm=BatchNorm))

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )
        # Five branches are concatenated before the projection.
        self.conv1 = nn.Conv2d(int(mid_channels * 5), inplanes, 1,
                               bias=False)
        self.bn1 = BatchNorm(inplanes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """Fuse the five branches and return the projected feature map."""
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x),
                    self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled map back to the branch resolution.
        pooled = F.interpolate(pooled,
                               size=branches[-1].size()[2:],
                               mode='bilinear',
                               align_corners=True)
        fused = torch.cat((*branches, pooled), dim=1)

        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal for convs; unit weight / zero bias for BN2d."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/geometry_kernel_attention.py
0 → 100644
View file @
19472568
import
warnings
import
time
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
build_attention
import
math
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
from
.ops.geometric_kernel_attn
import
GeometricKernelAttentionFunc
@ATTENTION.register_module()
class GeometrySptialCrossAttention(BaseModule):
    """Spatial cross-attention between BEV queries and camera features.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_cams (int): The number of cameras.
        pc_range: Stored on the module; not used inside this class.
        dropout (float): A Dropout layer on the attended output.
            Default: 0.1.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Stored on the module; not used inside this
            class.
        attention (dict): Config of the inner attention module, built with
            ``build_attention``.
    """

    def __init__(self,
                 embed_dims=256,
                 num_cams=6,
                 pc_range=None,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 attention=dict(
                     type='MSDeformableAttention3D',
                     embed_dims=256,
                     num_levels=4),
                 **kwargs
                 ):
        super(GeometrySptialCrossAttention, self).__init__(init_cfg)

        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.attention = build_attention(attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.batch_first = batch_first
        self.init_weight()

    def init_weight(self):
        """Default initialization for Parameters of Module."""
        xavier_init(self.output_proj, distribution='uniform', bias=0.)

    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos',
                          'reference_points_cam'))
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):
        """Attend from BEV queries to the cameras that can see them.

        Args:
            query (Tensor): BEV queries, indexed here as
                ``(bs, num_query, embed_dims)``.
            key (Tensor): Camera features with shape
                ``(num_cams, num_value, bs, embed_dims)``; defaults to
                ``query`` when None.
            value (Tensor): Same layout as ``key``; defaults to ``key``
                when None.
            residual (Tensor): Tensor added to the output. When None the
                incoming ``query`` (before ``query_pos`` is added) is used.
            query_pos (Tensor): Positional encoding added to ``query``.
            key_padding_mask, reference_points, flag: Accepted for
                interface compatibility; not used.
            spatial_shapes (Tensor): Forwarded to the inner attention.
            reference_points_cam (Tensor): Per-camera reference points,
                indexed as ``(num_cams, bs, num_query, D, 2)``.
            bev_mask (Tensor): Per-camera validity mask, indexed as
                ``(num_cams, bs, num_query, D)``.
            level_start_index (Tensor): Forwarded to the inner attention.

        Returns:
            Tensor: ``dropout(output_proj(slots)) + residual`` with the
            same shape as ``query``.
        """
        if key is None:
            key = query
        if value is None:
            value = key

        # Bug fix: the residual and the accumulator used to be defined only
        # when `residual` was None, so passing a residual raised NameError.
        inp_residual = query if residual is None else residual
        slots = torch.zeros_like(query)

        if query_pos is not None:
            query = query + query_pos

        bs, num_query, _ = query.size()
        D = reference_points_cam.size(3)

        # Query indices visible to each camera.
        # NOTE(review): only the mask of batch element 0 is consulted and
        # the result is reused for every sample -- confirm this is intended
        # for bs > 1.
        indexes = [
            mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            for mask_per_img in bev_mask
        ]
        max_len = max([len(each) for each in indexes])

        # Each camera only interacts with its corresponding BEV queries.
        # This step can greatly save GPU memory.
        queries_rebatch = query.new_zeros(
            [bs, self.num_cams, max_len, self.embed_dims])
        reference_points_rebatch = reference_points_cam.new_zeros(
            [bs, self.num_cams, max_len, D, 2])

        for j in range(bs):
            for i, reference_points_per_img in enumerate(
                    reference_points_cam):
                index_query_per_img = indexes[i]
                num_visible = len(index_query_per_img)
                queries_rebatch[j, i, :num_visible] = \
                    query[j, index_query_per_img]
                reference_points_rebatch[j, i, :num_visible] = \
                    reference_points_per_img[j, index_query_per_img]

        # (num_cams, num_value, bs, C) -> (bs * num_cams, num_value, C)
        _, num_value, bs, _ = key.shape
        key = key.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, num_value, self.embed_dims)
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, num_value, self.embed_dims)

        queries = self.attention(
            query=queries_rebatch.view(
                bs * self.num_cams, max_len, self.embed_dims),
            key=key,
            value=value,
            reference_points=reference_points_rebatch.view(
                bs * self.num_cams, max_len, D, 2),
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index).view(
                bs, self.num_cams, max_len, self.embed_dims)

        # Scatter the per-camera results back onto the BEV queries.
        for j in range(bs):
            for i, index_query_per_img in enumerate(indexes):
                slots[j, index_query_per_img] += \
                    queries[j, i, :len(index_query_per_img)]

        # Average over the number of cameras that see each query
        # (at least 1 to avoid division by zero).
        count = bev_mask.sum(-1) > 0
        count = count.permute(1, 2, 0).sum(-1)
        count = torch.clamp(count, min=1.0)
        slots = slots / count[..., None]
        slots = self.output_proj(slots)

        return self.dropout(slots) + inp_residual
@
ATTENTION
.
register_module
()
class
GeometryKernelAttention
(
BaseModule
):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
             embed_dims=256,
             num_heads=8,
             num_levels=4,
             num_points=4,
             kernel_size=(3, 3),
             dilation=1,
             im2col_step=64,
             dropout=0.1,
             batch_first=True,
             norm_cfg=None,
             init_cfg=None):
    """Create the projection layers and the fixed kernel sampling grid.

    Notes:
        * ``num_points`` and ``dropout`` are accepted but not read in this
          constructor; the stored ``self.num_points`` is always
          ``kernel_size[0] * kernel_size[1]``.
        * ``batch_first`` defaults to ``True`` in this signature.
        * ``self.output_proj`` is left as ``None``.

    Raises:
        ValueError: if ``embed_dims`` is not divisible by ``num_heads``.
    """
    super().__init__(init_cfg)
    if embed_dims % num_heads != 0:
        raise ValueError(f'embed_dims must be divisible by num_heads, '
                         f'but got {embed_dims} and {num_heads}')
    head_dim = embed_dims // num_heads

    self.norm_cfg = norm_cfg
    self.batch_first = batch_first
    self.output_proj = None
    self.fp16_enabled = False

    def _is_power_of_2(n):
        # Only non-negative Python ints are meaningful here.
        acceptable = isinstance(n, int) and n >= 0
        if not acceptable:
            raise ValueError(
                'invalid input for _is_power_of_2: {} (type: {})'.format(
                    n, type(n)))
        # A power of two has exactly one bit set; zero is excluded.
        return n != 0 and (n & (n - 1)) == 0

    # A power-of-two head dimension is only a performance hint, so a
    # mismatch produces a warning rather than an error.
    if not _is_power_of_2(head_dim):
        warnings.warn(
            "You'd better set embed_dims in "
            'MultiScaleDeformAttention to make '
            'the dimension of each attention head a power of 2 '
            'which is more efficient in our CUDA implementation.')

    self.im2col_step = im2col_step
    self.embed_dims = embed_dims
    self.num_levels = num_levels
    self.num_heads = num_heads
    self.kernel_size = kernel_size
    # One sampling point per kernel cell.
    self.num_points = kernel_size[0] * kernel_size[1]

    # There is no learned sampling-offset layer: only the per-point
    # attention logits and the value projection are trainable.
    self.attention_weights = nn.Linear(
        embed_dims, num_levels * self.num_points * self.num_heads)
    self.value_proj = nn.Linear(embed_dims, embed_dims)

    # Integer offsets of every kernel cell relative to the kernel centre,
    # scaled by the dilation, flattened to (kernel_h * kernel_w, 2).
    kernel_h, kernel_w = kernel_size
    y = (torch.arange(kernel_h) - kernel_h // 2) * dilation
    x = (torch.arange(kernel_w) - kernel_w // 2) * dilation
    mesh = torch.stack(torch.meshgrid(x, y))
    kernel_offsets = mesh.permute(1, 2, 0).reshape(kernel_h * kernel_w, 2)
    # Non-persistent buffer: follows the module across devices but is not
    # written to the state dict.
    self.register_buffer("grid_offsets", kernel_offsets, persistent=False)

    self.init_weights()
def init_weights(self):
    """Initialise the trainable layers of this module.

    The attention-weight layer gets constant weight 0 and bias 0, and the
    projection layers get Xavier-uniform weights with bias 0. There is no
    sampling-offset layer to initialise because the sampling grid is a
    fixed buffer.
    """
    constant_init(self.attention_weights, val=0., bias=0.)
    # NOTE(review): ``output_proj`` is assigned ``None`` in ``__init__``, so
    # the second iteration relies on ``xavier_init`` tolerating an object
    # without weight/bias attributes - confirm against the helper.
    for projection in (self.value_proj, self.output_proj):
        xavier_init(projection, distribution='uniform', bias=0.)
    self._is_init = True
def forward_kernel_multihead_attention(self, value, spatial_shapes,
                                       sampling_locations,
                                       attention_weights):
    """Pure-PyTorch (gather based) geometric kernel attention.

    Every sampling location is clamped to its feature level, converted to
    a flat key index, gathered from ``value`` and combined with the
    attention weights.

    Args:
        value (Tensor): Values with shape
            (bs, num_keys, num_heads, dim).
        spatial_shapes (Tensor): Spatial shape of each feature map with
            shape (num_levels, 2); the last dimension is (h, w).
        sampling_locations (Tensor): Integer pixel coordinates used for
            indexing, with shape
            (bs, num_queries, num_heads, num_levels, num_points, 2);
            the last dimension is (x, y). The tensor is NOT modified.
        attention_weights (Tensor): Weights of the sampling points with
            shape (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, num_heads * dim).
    """
    bs, num_keys, num_heads, dim = value.shape
    # Flatten to (bs * num_heads * num_keys, dim) so a single flat index
    # addresses batch, head and key at once.
    value = value.transpose(1, 2).contiguous().view(
        bs * num_heads * num_keys, dim)
    _, num_queries, num_heads, num_levels, num_points, _ = \
        sampling_locations.shape

    with torch.no_grad():
        sampling_index = sampling_locations.new_zeros(
            (bs, num_queries, num_heads, num_levels,
             num_points)).to(value.device)
        start_index = 0
        for level, (H_, W_) in enumerate(spatial_shapes):
            # Clamp out-of-place: the in-place variant used to overwrite
            # the caller's ``sampling_locations`` tensor.
            x = sampling_locations[:, :, :, level, :, 0].clamp(
                min=0, max=W_ - 1)
            y = sampling_locations[:, :, :, level, :, 1].clamp(
                min=0, max=H_ - 1)
            # Row-major index inside the level, shifted by the number of
            # keys of all previous levels.
            sampling_index[:, :, :, level] = start_index + x + y * W_
            start_index += H_ * W_

        # (bs, q, head, l, p) -> (bs, head, q * l * p), then add head offset.
        sampling_index = sampling_index.transpose(1, 2).reshape(
            bs, num_heads, -1)
        sampling_index = sampling_index + \
            (torch.arange(num_heads).to(sampling_index) * num_keys).view(
                1, num_heads, 1)
        # Add the batch offset.
        sampling_index = sampling_index.reshape(bs, -1) + \
            (torch.arange(bs).to(sampling_index) * num_keys *
             num_heads).view(bs, 1)

    sampling_value = value[sampling_index].view(
        bs, num_heads, num_queries, num_levels * num_points, dim)
    attention_weights = attention_weights.transpose(1, 2).contiguous().view(
        bs, num_heads, num_queries, num_levels * num_points, 1)

    # Weighted sum over all sampled points, then back to (bs, q, head, dim).
    output = (sampling_value * attention_weights).sum(-2).transpose(
        1, 2).contiguous()
    return output.view(bs, num_queries, -1)
def forward(self,
            query,
            key=None,
            value=None,
            identity=None,
            query_pos=None,
            key_padding_mask=None,
            reference_points=None,
            spatial_shapes=None,
            level_start_index=None,
            **kwargs):
    """Forward pass of the geometric kernel attention.

    Args:
        query (Tensor): Query with shape (bs, num_query, embed_dims) when
            ``self.batch_first`` is True, otherwise
            (num_query, bs, embed_dims).
        key (Tensor): Not used by this module.
        value (Tensor): Values in the same layout as ``query`` with
            ``num_value`` entries. If None, ``query`` is used.
        identity (Tensor): Accepted for interface compatibility; this
            method does not add a residual, so it is not used.
        query_pos (Tensor): Positional encoding added to ``query``.
            Default: None.
        key_padding_mask (Tensor): Mask with shape (bs, num_value);
            masked positions of the projected value are set to 0.
        reference_points (Tensor): Normalized (x, y) reference points
            with shape (bs, num_query, num_Z_anchors, 2). They are scaled
            by the (w, h) of every level. The anchor dimension takes the
            place of the head dimension in the sampling locations -
            presumably it has to equal ``self.num_heads``; confirm against
            the caller.
        spatial_shapes (Tensor): Spatial shape of the feature levels with
            shape (num_levels, 2); the last dimension is (h, w).
        level_start_index (Tensor): Start index of each level, forwarded
            unchanged to the attention kernel.

    Returns:
        Tensor: Result of ``GeometricKernelAttentionFunc``; it is
        permuted to (num_query, bs, embed_dims) when ``self.batch_first``
        is False and left as (bs, num_query, embed_dims) otherwise.

    Raises:
        AssertionError: if the level sizes do not sum to ``num_value`` or
            if ``reference_points`` has a last dimension of 4, which is
            not supported.
        ValueError: if the last dimension of ``reference_points`` is
            neither 2 nor 4.
    """
    if value is None:
        value = query
    if query_pos is not None:
        query = query + query_pos
    if not self.batch_first:
        # change to (bs, num_query, embed_dims)
        query = query.permute(1, 0, 2)
        value = value.permute(1, 0, 2)

    bs, num_query, _ = query.shape
    bs, num_value, _ = value.shape
    assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

    value = self.value_proj(value)
    if key_padding_mask is not None:
        value = value.masked_fill(key_padding_mask[..., None], 0.0)
    value = value.view(bs, num_value, self.num_heads, -1)

    # Softmax is taken jointly over all levels and kernel points of a head.
    attention_weights = self.attention_weights(query).view(
        bs, num_query, self.num_heads, self.num_levels * self.num_points)
    attention_weights = attention_weights.softmax(-1)
    attention_weights = attention_weights.view(bs, num_query,
                                               self.num_heads,
                                               self.num_levels,
                                               self.num_points)

    if reference_points.shape[-1] == 2:
        with torch.no_grad():
            # (w, h) per level, matching the (x, y) order of the points.
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            # Unpacking also checks that reference_points is 4-D.
            bs, num_query, num_Z_anchors, xy = reference_points.shape
            # (K, 2) -> (1, 1, 1, 1, K, 2)
            offsets = self.grid_offsets[None, None, None, None]
            # (bs, q, nz, 1, 2) * (l, 2) -> (bs, q, nz, l, 2) in pixels
            reference_points = \
                reference_points[:, :, :, None, :] * offset_normalizer
            # (bs, q, nz, l, K, 2) integer pixel locations
            sampling_locations = (
                reference_points[:, :, :, :, None, :] +
                offsets).round().long()
    elif reference_points.shape[-1] == 4:
        # Raised explicitly instead of `assert False` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            'reference_points with a last dimension of 4 (reference '
            'boxes) are not supported by geometric kernel attention.')
    else:
        raise ValueError(
            f'Last dim of reference_points must be'
            f' 2 or 4, but get {reference_points.shape[-1]} instead.')

    # sampling_locations: (bs, num_query, num_heads, num_levels, K, 2)
    # attention_weights:  (bs, num_query, num_heads, num_levels, K)
    output = GeometricKernelAttentionFunc.apply(
        value, spatial_shapes, level_start_index,
        sampling_locations.contiguous(), attention_weights,
        self.im2col_step)

    if not self.batch_first:
        output = output.permute(1, 0, 2)
    return output
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/__init__.py
0 → 100644
View file @
19472568
from
.function
import
GeometricKernelAttentionFunc
\ No newline at end of file
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/function/__init__.py
0 → 100644
View file @
19472568
from
.geometric_kernel_attn_func
import
GeometricKernelAttentionFunc
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/function/geometric_kernel_attn_func.py
0 → 100644
View file @
19472568
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
torch
import
torch.nn.functional
as
F
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
import
GeometricKernelAttention
as
GKA
class GeometricKernelAttentionFunc(Function):
    """Autograd wrapper around the compiled geometric kernel attention op.

    Only ``value`` and ``attention_weights`` receive gradients; the
    spatial shapes, level start indices, sampling locations and
    ``im2col_step`` are treated as non-differentiable inputs.
    """

    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """Run the forward kernel and stash the inputs for ``backward``.

        The arguments are handed to
        ``GKA.geometric_kernel_attn_cuda_forward`` unchanged and in the
        same order; its result is returned as is.
        """
        ctx.im2col_step = im2col_step
        output = GKA.geometric_kernel_attn_cuda_forward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``value`` and ``attention_weights``.

        Returns one entry per ``forward`` input; the non-differentiable
        inputs get ``None``.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        # The CUDA entry point asserts that grad_output is contiguous, but
        # autograd may deliver a strided gradient (e.g. when the caller
        # permuted the output), so make it contiguous first. This is a
        # no-op for tensors that already are.
        grad_value, grad_attn_weight = \
            GKA.geometric_kernel_attn_cuda_backward(
                value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights,
                grad_output.contiguous(), ctx.im2col_step)
        return grad_value, None, None, None, grad_attn_weight, None
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/setup.py
0 → 100644
View file @
19472568
import
os
import
glob
import
torch
from
torch.utils.cpp_extension
import
CUDA_HOME
from
torch.utils.cpp_extension
import
CppExtension
from
torch.utils.cpp_extension
import
CUDAExtension
from
setuptools
import
find_packages
from
setuptools
import
setup
# Python packages this extension depends on.
# NOTE(review): the setup() call in this file does not pass this list
# (there is no install_requires argument) - confirm whether it is
# intentionally unused.
requirements
=
[
"torch"
,
"torchvision"
]
def get_extensions():
    """Describe the native extension that ``setup()`` has to build.

    Collects every ``*.cpp`` and ``*.cu`` file from the ``src`` directory
    next to this script and wraps them in a single ``CUDAExtension`` named
    ``GeometricKernelAttention``.

    Returns:
        list: A one-element list holding the extension description.
    """
    src_root = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "src")
    cpp_sources = glob.glob(os.path.join(src_root, "*.cpp"))
    cuda_sources = glob.glob(os.path.join(src_root, "*.cu"))

    # The CUDA build is unconditional: the runtime availability check
    # (torch.cuda.is_available() and CUDA_HOME is not None) is disabled in
    # this script, so the error below cannot trigger.
    cuda_build_forced = True
    if not cuda_build_forced:
        raise NotImplementedError('Cuda is not availabel')

    macros = [("WITH_CUDA", None)]
    compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    }

    # C++ sources first, CUDA sources second.
    all_sources = [
        os.path.join(src_root, path)
        for path in cpp_sources + cuda_sources
    ]

    return [
        CUDAExtension(
            "GeometricKernelAttention",
            all_sources,
            include_dirs=[src_root],
            define_macros=macros,
            extra_compile_args=compile_args,
        )
    ]
# Package definition: builds the extension(s) returned by get_extensions()
# with torch's BuildExtension registered as the "build_ext" command.
# Packages named "configs" and "tests" are excluded from package discovery.
setup
(
name
=
"GeometricKernelAttention"
,
version
=
"1.0"
,
author
=
"Tianheng Cheng"
,
url
=
"https://github.com/hustvl"
,
description
=
"PyTorch Wrapper for CUDA Functions of Multi-Scale Geometric Kernel Attention"
,
packages
=
find_packages
(
exclude
=
(
"configs"
,
"tests"
,)),
ext_modules
=
get_extensions
(),
cmdclass
=
{
"build_ext"
:
torch
.
utils
.
cpp_extension
.
BuildExtension
},
)
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/src/geometric_kernel_attn.h
0 → 100644
View file @
19472568
#pragma once
// #include "cpu/ms_deform_attn_cpu.h"
// #ifdef WITH_CUDA
#include "geometric_kernel_attn_cuda.h"
/// @brief Device dispatcher for the geometric kernel attention forward pass.
///
/// Forwards all arguments unchanged to geometric_kernel_attn_cuda_forward()
/// when `value` lives on a CUDA device.
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
///
/// @throws c10::Error ("Not implemented on the CPU") for non-CUDA `value`.
inline at::Tensor
geometric_kernel_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
    if (value.is_cuda())
    {
        return geometric_kernel_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
    }
    AT_ERROR("Not implemented on the CPU");
}
/// @brief Device dispatcher for the geometric kernel attention backward pass.
///
/// Forwards all arguments unchanged to geometric_kernel_attn_cuda_backward()
/// when `value` lives on a CUDA device and returns its result.
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
///
/// @throws c10::Error ("Not implemented on the CPU") for non-CUDA `value`.
inline std::vector<at::Tensor>
geometric_kernel_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
    if (value.is_cuda())
    {
        return geometric_kernel_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
    }
    AT_ERROR("Not implemented on the CPU");
}
docker-hub/MapTRv2/MapTR/projects/mmdet3d_plugin/maptr/modules/ops/geometric_kernel_attn/src/geometric_kernel_attn_cuda.cu
0 → 100644
View file @
19472568
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <vector>
#include "geometric_kernel_attn_cuda_kernel.cuh"
/// @brief Host-side launcher of the geometric kernel attention forward pass.
///
/// The batch is processed in chunks of `min(batch, im2col_step)` samples;
/// each chunk is handed to multiscale_kernel_attn_forward_cuda() on the
/// current CUDA stream.
///
/// @param value             (batch, spatial_size, num_heads, channels), floating point.
/// @param spatial_shapes    int64 tensor; its first dimension is the number of levels.
/// @param level_start_index int64 tensor, passed to the kernel unchanged.
/// @param sampling_loc      int64 tensor; dim 1 is num_query and dim 4 is num_point.
///                          The per-sample stride used here assumes the layout
///                          (batch, num_query, num_heads, num_levels, num_point, 2).
/// @param attn_weight       Same scalar type as `value`; assumed layout
///                          (batch, num_query, num_heads, num_levels, num_point).
/// @param im2col_step       Maximum number of samples per kernel launch.
/// @return Tensor of shape (batch, num_query, num_heads * channels).
/// @throws c10::Error if an input is not contiguous / not on CUDA, if `batch`
///         or `im2col_step` is not positive, or if `batch` is not a multiple
///         of the effective step.
at::Tensor geometric_kernel_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Guard the modulo/division below: an empty batch or a non-positive step
    // would otherwise divide by zero.
    AT_ASSERTM(im2col_step_ > 0,
               "batch(", batch, ") and im2col_step(", im2col_step, ") must both be positive");
    // The message arguments are concatenated, not printf-formatted.
    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch / im2col_step_, batch_n, num_query, num_heads, channels});

    // Per-sample element counts, kept in 64 bits so the pointer offsets of
    // later chunks cannot overflow `int` for large inputs.
    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;
    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;
    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;

    for (int n = 0; n < batch / im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "multiscale_kernel_attn_forward_cuda", ([&] {
            multiscale_kernel_attn_forward_cuda(
                at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<int64_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());
        }));
    }

    // Merge the head and channel dimensions.
    output = output.view({batch, num_query, num_heads * channels});

    return output;
}
/// @brief Host-side launcher of the geometric kernel attention backward pass.
///
/// The batch is processed in chunks of `min(batch, im2col_step)` samples;
/// each chunk is handed to multiscale_kernel_attn_backward_cuda() on the
/// current CUDA stream, which writes into the zero-initialised gradient
/// buffers allocated here.
///
/// @param value             (batch, spatial_size, num_heads, channels), floating point.
/// @param spatial_shapes    int64 tensor; its first dimension is the number of levels.
/// @param level_start_index int64 tensor, passed to the kernel unchanged.
/// @param sampling_loc      int64 tensor; dim 1 is num_query and dim 4 is num_point.
///                          The per-sample stride used here assumes the layout
///                          (batch, num_query, num_heads, num_levels, num_point, 2).
/// @param attn_weight       Same scalar type as `value`; assumed layout
///                          (batch, num_query, num_heads, num_levels, num_point).
/// @param grad_output       Contiguous gradient of the forward output; it is
///                          viewed as (batch, num_query, num_heads, channels).
/// @param im2col_step       Maximum number of samples per kernel launch.
/// @return {grad_value, grad_attn_weight}, shaped like `value` and
///         `attn_weight` respectively.
/// @throws c10::Error if an input is not contiguous / not on CUDA, if `batch`
///         or `im2col_step` is not positive, or if `batch` is not a multiple
///         of the effective step.
std::vector<at::Tensor> geometric_kernel_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Guard the modulo/division below: an empty batch or a non-positive step
    // would otherwise divide by zero.
    AT_ASSERTM(im2col_step_ > 0,
               "batch(", batch, ") and im2col_step(", im2col_step, ") must both be positive");
    // The message arguments are concatenated, not printf-formatted.
    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto grad_value = at::zeros_like(value);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;

    // Per-sample element counts, kept in 64 bits so the pointer offsets of
    // later chunks cannot overflow `int` for large inputs.
    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;
    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;
    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;

    auto grad_output_n = grad_output.view({batch / im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch / im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "multiscale_kernel_attn_backward_cuda", ([&] {
            multiscale_kernel_attn_backward_cuda(
                at::cuda::getCurrentCUDAStream(),
                grad_output_g.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<int64_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
        }));
    }

    return {
        grad_value, grad_attn_weight
    };
}
Prev
1
…
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment