Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
lishj6
BEVFomer
Commits
4cd43886
Commit
4cd43886
authored
Sep 01, 2025
by
lishj6
🏸
Browse files
init
parent
a9a1fe81
Changes
207
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4330 additions
and
0 deletions
+4330
-0
projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py
projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py
+2
-0
projects/mmdet3d_plugin/bevformer/dense_heads/bev_head.py
projects/mmdet3d_plugin/bevformer/dense_heads/bev_head.py
+132
-0
projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
...ts/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
+683
-0
projects/mmdet3d_plugin/bevformer/detectors/__init__.py
projects/mmdet3d_plugin/bevformer/detectors/__init__.py
+3
-0
projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
+295
-0
projects/mmdet3d_plugin/bevformer/detectors/bevformerV2.py
projects/mmdet3d_plugin/bevformer/detectors/bevformerV2.py
+269
-0
projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
...ects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
+89
-0
projects/mmdet3d_plugin/bevformer/hooks/__init__.py
projects/mmdet3d_plugin/bevformer/hooks/__init__.py
+1
-0
projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
+13
-0
projects/mmdet3d_plugin/bevformer/modules/__init__.py
projects/mmdet3d_plugin/bevformer/modules/__init__.py
+8
-0
projects/mmdet3d_plugin/bevformer/modules/custom_base_transformer_layer.py
...plugin/bevformer/modules/custom_base_transformer_layer.py
+260
-0
projects/mmdet3d_plugin/bevformer/modules/decoder.py
projects/mmdet3d_plugin/bevformer/modules/decoder.py
+345
-0
projects/mmdet3d_plugin/bevformer/modules/encoder.py
projects/mmdet3d_plugin/bevformer/modules/encoder.py
+591
-0
projects/mmdet3d_plugin/bevformer/modules/group_attention.py
projects/mmdet3d_plugin/bevformer/modules/group_attention.py
+162
-0
projects/mmdet3d_plugin/bevformer/modules/multi_scale_deformable_attn_function.py
...bevformer/modules/multi_scale_deformable_attn_function.py
+163
-0
projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
...det3d_plugin/bevformer/modules/spatial_cross_attention.py
+399
-0
projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
...det3d_plugin/bevformer/modules/temporal_self_attention.py
+272
-0
projects/mmdet3d_plugin/bevformer/modules/transformer.py
projects/mmdet3d_plugin/bevformer/modules/transformer.py
+289
-0
projects/mmdet3d_plugin/bevformer/modules/transformerV2.py
projects/mmdet3d_plugin/bevformer/modules/transformerV2.py
+353
-0
projects/mmdet3d_plugin/bevformer/runner/__init__.py
projects/mmdet3d_plugin/bevformer/runner/__init__.py
+1
-0
No files found.
projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py
0 → 100644
View file @
4cd43886
from
.bevformer_head
import
BEVFormerHead
,
BEVFormerHead_GroupDETR
from
.bev_head
import
BEVHead
projects/mmdet3d_plugin/bevformer/dense_heads/bev_head.py
0 → 100644
View file @
4cd43886
import
copy
from
re
import
I
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmdet.core
import
(
multi_apply
,
multi_apply
,
reduce_mean
)
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models
import
HEADS
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet3d.core.bbox.coders
import
build_bbox_coder
from
traitlets
import
import_item
from
projects.mmdet3d_plugin.core.bbox.util
import
normalize_bbox
from
mmcv.cnn.bricks.transformer
import
build_positional_encoding
from
mmcv.runner
import
BaseModule
,
force_fp32
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
import
numpy
as
np
import
mmcv
import
cv2
as
cv
from
projects.mmdet3d_plugin.bevformer.modules
import
PerceptionTransformerBEVEncoder
from
mmdet.models.utils
import
build_transformer
from
mmdet3d.models.builder
import
build_head
from
mmdet3d.models.dense_heads.free_anchor3d_head
import
FreeAnchor3DHead
@
HEADS
.
register_module
()
class
BEVHead
(
BaseModule
):
def
__init__
(
self
,
bev_h
,
bev_w
,
pc_range
,
embed_dims
,
transformer
,
positional_encoding
:
dict
,
pts_bbox_head_3d
:
dict
,
init_cfg
=
None
,
**
kwargs
,
):
super
(
BEVHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
self
.
bev_h
=
bev_h
self
.
bev_w
=
bev_w
self
.
embed_dims
=
embed_dims
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
self
.
transformer
:
PerceptionTransformerBEVEncoder
=
build_transformer
(
transformer
)
self
.
positional_encoding
=
build_positional_encoding
(
positional_encoding
)
pts_bbox_head_3d
.
update
(
kwargs
)
self
.
pts_bbox_head_3d
=
build_head
(
pts_bbox_head_3d
)
self
.
real_w
=
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
]
self
.
real_h
=
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
]
self
.
_init_layers
()
def
init_weights
(
self
):
"""Initialize weights of the Multi View BEV Encoder"""
self
.
transformer
.
init_weights
()
def
_init_layers
(
self
):
"""Initialize classification branch and regression branch of head."""
self
.
bev_embedding
=
nn
.
Embedding
(
self
.
bev_h
*
self
.
bev_w
,
self
.
embed_dims
)
@
force_fp32
(
apply_to
=
(
'mlvl_feats'
,
'pred_bev'
))
def
forward
(
self
,
mlvl_feats
,
img_metas
,
prev_bev
=
None
,
only_bev
=
False
):
bs
,
num_cam
,
_
,
_
,
_
=
mlvl_feats
[
0
].
shape
dtype
=
mlvl_feats
[
0
].
dtype
bev_queries
=
self
.
bev_embedding
.
weight
.
to
(
dtype
)
bev_mask
=
torch
.
zeros
((
bs
,
self
.
bev_h
,
self
.
bev_w
),
device
=
bev_queries
.
device
).
to
(
dtype
)
bev_pos
=
self
.
positional_encoding
(
bev_mask
).
to
(
dtype
)
bev_embed
=
self
.
transformer
(
mlvl_feats
,
bev_queries
,
self
.
bev_h
,
self
.
bev_w
,
grid_length
=
(
self
.
real_h
/
self
.
bev_h
,
self
.
real_w
/
self
.
bev_w
),
bev_pos
=
bev_pos
,
img_metas
=
img_metas
,
prev_bev
=
prev_bev
,
)
if
only_bev
:
return
bev_embed
bev_feature
=
bev_embed
.
permute
(
0
,
2
,
1
).
reshape
(
bs
,
self
.
embed_dims
,
self
.
bev_h
,
self
.
bev_w
)
ret
=
{}
ret
[
'pred'
]
=
self
.
pts_bbox_head_3d
([
bev_feature
,])
if
not
self
.
training
:
ret
[
'bev_embed'
]
=
bev_embed
return
ret
@
force_fp32
(
apply_to
=
(
'ret'
))
def
loss
(
self
,
gt_bboxes_list
,
gt_labels_list
,
ret
,
gt_bboxes_ignore
=
None
,
img_metas
=
None
):
assert
gt_bboxes_ignore
is
None
return
self
.
pts_bbox_head_3d
.
loss
(
gt_bboxes_list
,
gt_labels_list
,
ret
[
'pred'
],
gt_bboxes_ignore
=
gt_bboxes_ignore
,
img_metas
=
img_metas
)
@
force_fp32
(
apply_to
=
(
'ret'
))
def
get_bboxes
(
self
,
ret
,
img_metas
,
rescale
=
False
):
return
self
.
pts_bbox_head_3d
.
get_bboxes
(
ret
[
'pred'
],
img_metas
)
@
HEADS
.
register_module
()
class
FreeAnchor3DHeadV2
(
FreeAnchor3DHead
):
@
force_fp32
(
apply_to
=
(
'pred'
))
def
loss
(
self
,
gt_bboxes_list
,
gt_labels_list
,
pred
,
gt_bboxes_ignore
=
None
,
img_metas
=
None
):
cls_scores
,
bbox_preds
,
dir_cls_preds
=
pred
return
super
().
loss
(
cls_scores
,
bbox_preds
,
dir_cls_preds
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
gt_bboxes_ignore
)
@
force_fp32
(
apply_to
=
(
'pred'
))
def
get_bboxes
(
self
,
pred
,
img_metas
,
rescale
=
False
):
cls_scores
,
bbox_preds
,
dir_cls_preds
=
pred
return
super
().
get_bboxes
(
cls_scores
,
bbox_preds
,
dir_cls_preds
,
img_metas
,
cfg
=
None
,
rescale
=
rescale
)
\ No newline at end of file
projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
0 → 100644
View file @
4cd43886
import
copy
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmdet.core
import
(
multi_apply
,
multi_apply
,
reduce_mean
)
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models
import
HEADS
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet3d.core.bbox.coders
import
build_bbox_coder
from
projects.mmdet3d_plugin.core.bbox.util
import
normalize_bbox
from
mmcv.runner
import
force_fp32
,
auto_fp16
@
HEADS
.
register_module
()
class
BEVFormerHead
(
DETRHead
):
"""Head of Detr3D.
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
bev_h, bev_w (int): spatial shape of BEV queries.
"""
def
__init__
(
self
,
*
args
,
with_box_refine
=
False
,
as_two_stage
=
False
,
transformer
=
None
,
bbox_coder
=
None
,
num_cls_fcs
=
2
,
code_weights
=
None
,
bev_h
=
30
,
bev_w
=
30
,
**
kwargs
):
self
.
bev_h
=
bev_h
self
.
bev_w
=
bev_w
self
.
fp16_enabled
=
False
self
.
with_box_refine
=
with_box_refine
self
.
as_two_stage
=
as_two_stage
if
self
.
as_two_stage
:
transformer
[
'as_two_stage'
]
=
self
.
as_two_stage
if
'code_size'
in
kwargs
:
self
.
code_size
=
kwargs
[
'code_size'
]
else
:
self
.
code_size
=
10
if
code_weights
is
not
None
:
self
.
code_weights
=
code_weights
else
:
self
.
code_weights
=
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
0.2
,
0.2
]
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
self
.
pc_range
=
self
.
bbox_coder
.
pc_range
self
.
real_w
=
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
]
self
.
real_h
=
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
]
self
.
num_cls_fcs
=
num_cls_fcs
-
1
super
(
BEVFormerHead
,
self
).
__init__
(
*
args
,
transformer
=
transformer
,
**
kwargs
)
self
.
code_weights
=
nn
.
Parameter
(
torch
.
tensor
(
self
.
code_weights
,
requires_grad
=
False
),
requires_grad
=
False
)
def
_init_layers
(
self
):
"""Initialize classification branch and regression branch of head."""
cls_branch
=
[]
for
_
in
range
(
self
.
num_reg_fcs
):
cls_branch
.
append
(
Linear
(
self
.
embed_dims
,
self
.
embed_dims
))
cls_branch
.
append
(
nn
.
LayerNorm
(
self
.
embed_dims
))
cls_branch
.
append
(
nn
.
ReLU
(
inplace
=
True
))
cls_branch
.
append
(
Linear
(
self
.
embed_dims
,
self
.
cls_out_channels
))
fc_cls
=
nn
.
Sequential
(
*
cls_branch
)
reg_branch
=
[]
for
_
in
range
(
self
.
num_reg_fcs
):
reg_branch
.
append
(
Linear
(
self
.
embed_dims
,
self
.
embed_dims
))
reg_branch
.
append
(
nn
.
ReLU
())
reg_branch
.
append
(
Linear
(
self
.
embed_dims
,
self
.
code_size
))
reg_branch
=
nn
.
Sequential
(
*
reg_branch
)
def
_get_clones
(
module
,
N
):
return
nn
.
ModuleList
([
copy
.
deepcopy
(
module
)
for
i
in
range
(
N
)])
# last reg_branch is used to generate proposal from
# encode feature map when as_two_stage is True.
num_pred
=
(
self
.
transformer
.
decoder
.
num_layers
+
1
)
if
\
self
.
as_two_stage
else
self
.
transformer
.
decoder
.
num_layers
if
self
.
with_box_refine
:
self
.
cls_branches
=
_get_clones
(
fc_cls
,
num_pred
)
self
.
reg_branches
=
_get_clones
(
reg_branch
,
num_pred
)
else
:
self
.
cls_branches
=
nn
.
ModuleList
(
[
fc_cls
for
_
in
range
(
num_pred
)])
self
.
reg_branches
=
nn
.
ModuleList
(
[
reg_branch
for
_
in
range
(
num_pred
)])
if
not
self
.
as_two_stage
:
self
.
bev_embedding
=
nn
.
Embedding
(
self
.
bev_h
*
self
.
bev_w
,
self
.
embed_dims
)
self
.
query_embedding
=
nn
.
Embedding
(
self
.
num_query
,
self
.
embed_dims
*
2
)
def
init_weights
(
self
):
"""Initialize weights of the DeformDETR head."""
self
.
transformer
.
init_weights
()
if
self
.
loss_cls
.
use_sigmoid
:
bias_init
=
bias_init_with_prob
(
0.01
)
for
m
in
self
.
cls_branches
:
nn
.
init
.
constant_
(
m
[
-
1
].
bias
,
bias_init
)
@
auto_fp16
(
apply_to
=
(
'mlvl_feats'
))
def
forward
(
self
,
mlvl_feats
,
img_metas
,
prev_bev
=
None
,
only_bev
=
False
):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
Returns:
all_cls_scores (Tensor): Outputs from the classification head,
\
shape [nb_dec, bs, num_query, cls_out_channels]. Note
\
cls_out_channels should includes background.
all_bbox_preds (Tensor): Sigmoid outputs from the regression
\
head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy).
\
Shape [nb_dec, bs, num_query, 9].
"""
bs
,
num_cam
,
_
,
_
,
_
=
mlvl_feats
[
0
].
shape
dtype
=
mlvl_feats
[
0
].
dtype
object_query_embeds
=
self
.
query_embedding
.
weight
.
to
(
dtype
)
bev_queries
=
self
.
bev_embedding
.
weight
.
to
(
dtype
)
bev_mask
=
torch
.
zeros
((
bs
,
self
.
bev_h
,
self
.
bev_w
),
device
=
bev_queries
.
device
).
to
(
dtype
)
bev_pos
=
self
.
positional_encoding
(
bev_mask
).
to
(
dtype
)
if
only_bev
:
# only use encoder to obtain BEV features, TODO: refine the workaround
return
self
.
transformer
.
get_bev_features
(
mlvl_feats
,
bev_queries
,
self
.
bev_h
,
self
.
bev_w
,
grid_length
=
(
self
.
real_h
/
self
.
bev_h
,
self
.
real_w
/
self
.
bev_w
),
bev_pos
=
bev_pos
,
img_metas
=
img_metas
,
prev_bev
=
prev_bev
,
)
else
:
outputs
=
self
.
transformer
(
mlvl_feats
,
bev_queries
,
object_query_embeds
,
self
.
bev_h
,
self
.
bev_w
,
grid_length
=
(
self
.
real_h
/
self
.
bev_h
,
self
.
real_w
/
self
.
bev_w
),
bev_pos
=
bev_pos
,
reg_branches
=
self
.
reg_branches
if
self
.
with_box_refine
else
None
,
# noqa:E501
cls_branches
=
self
.
cls_branches
if
self
.
as_two_stage
else
None
,
img_metas
=
img_metas
,
prev_bev
=
prev_bev
)
bev_embed
,
hs
,
init_reference
,
inter_references
=
outputs
hs
=
hs
.
permute
(
0
,
2
,
1
,
3
)
outputs_classes
=
[]
outputs_coords
=
[]
for
lvl
in
range
(
hs
.
shape
[
0
]):
if
lvl
==
0
:
reference
=
init_reference
else
:
reference
=
inter_references
[
lvl
-
1
]
reference
=
inverse_sigmoid
(
reference
)
outputs_class
=
self
.
cls_branches
[
lvl
](
hs
[
lvl
])
tmp
=
self
.
reg_branches
[
lvl
](
hs
[
lvl
])
# TODO: check the shape of reference
assert
reference
.
shape
[
-
1
]
==
3
tmp
[...,
0
:
2
]
+=
reference
[...,
0
:
2
]
tmp
[...,
0
:
2
]
=
tmp
[...,
0
:
2
].
sigmoid
()
tmp
[...,
4
:
5
]
+=
reference
[...,
2
:
3
]
tmp
[...,
4
:
5
]
=
tmp
[...,
4
:
5
].
sigmoid
()
tmp
[...,
0
:
1
]
=
(
tmp
[...,
0
:
1
]
*
(
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
])
+
self
.
pc_range
[
0
])
tmp
[...,
1
:
2
]
=
(
tmp
[...,
1
:
2
]
*
(
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
])
+
self
.
pc_range
[
1
])
tmp
[...,
4
:
5
]
=
(
tmp
[...,
4
:
5
]
*
(
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
])
+
self
.
pc_range
[
2
])
# TODO: check if using sigmoid
outputs_coord
=
tmp
outputs_classes
.
append
(
outputs_class
)
outputs_coords
.
append
(
outputs_coord
)
outputs_classes
=
torch
.
stack
(
outputs_classes
)
outputs_coords
=
torch
.
stack
(
outputs_coords
)
outs
=
{
'bev_embed'
:
bev_embed
,
'all_cls_scores'
:
outputs_classes
,
'all_bbox_preds'
:
outputs_coords
,
'enc_cls_scores'
:
None
,
'enc_bbox_preds'
:
None
,
}
return
outs
def
_get_target_single
(
self
,
cls_score
,
bbox_pred
,
gt_labels
,
gt_bboxes
,
gt_bboxes_ignore
=
None
):
""""Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (cx, cy, w, h) and
shape [num_query, 4].
gt_bboxes (Tensor): Ground truth bboxes for one image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (Tensor): Ground truth class indices for one image
with shape (num_gts, ).
gt_bboxes_ignore (Tensor, optional): Bounding boxes
which can be ignored. Default None.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
- label_weights (Tensor]): Label weights of each image.
- bbox_targets (Tensor): BBox targets of each image.
- bbox_weights (Tensor): BBox weights of each image.
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
num_bboxes
=
bbox_pred
.
size
(
0
)
# assigner and sampler
gt_c
=
gt_bboxes
.
shape
[
-
1
]
assign_result
=
self
.
assigner
.
assign
(
bbox_pred
,
cls_score
,
gt_bboxes
,
gt_labels
,
gt_bboxes_ignore
)
sampling_result
=
self
.
sampler
.
sample
(
assign_result
,
bbox_pred
,
gt_bboxes
)
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
# label targets
labels
=
gt_bboxes
.
new_full
((
num_bboxes
,),
self
.
num_classes
,
dtype
=
torch
.
long
)
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
label_weights
=
gt_bboxes
.
new_ones
(
num_bboxes
)
# bbox targets
bbox_targets
=
torch
.
zeros_like
(
bbox_pred
)[...,
:
gt_c
]
bbox_weights
=
torch
.
zeros_like
(
bbox_pred
)
bbox_weights
[
pos_inds
]
=
1.0
# DETR
bbox_targets
[
pos_inds
]
=
sampling_result
.
pos_gt_bboxes
return
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pos_inds
,
neg_inds
)
def
get_targets
(
self
,
cls_scores_list
,
bbox_preds_list
,
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_ignore_list
=
None
):
""""Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
(cx, cy, w, h) and shape [num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all
\
images.
- bbox_targets_list (list[Tensor]): BBox targets for all
\
images.
- bbox_weights_list (list[Tensor]): BBox weights for all
\
images.
- num_total_pos (int): Number of positive samples in all
\
images.
- num_total_neg (int): Number of negative samples in all
\
images.
"""
assert
gt_bboxes_ignore_list
is
None
,
\
'Only supports for gt_bboxes_ignore setting to None.'
num_imgs
=
len
(
cls_scores_list
)
gt_bboxes_ignore_list
=
[
gt_bboxes_ignore_list
for
_
in
range
(
num_imgs
)
]
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
self
.
_get_target_single
,
cls_scores_list
,
bbox_preds_list
,
gt_labels_list
,
gt_bboxes_list
,
gt_bboxes_ignore_list
)
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_total_pos
,
num_total_neg
)
def
loss_single
(
self
,
cls_scores
,
bbox_preds
,
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_ignore_list
=
None
):
""""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (cx, cy, w, h) and
shape [bs, num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
num_imgs
=
cls_scores
.
size
(
0
)
cls_scores_list
=
[
cls_scores
[
i
]
for
i
in
range
(
num_imgs
)]
bbox_preds_list
=
[
bbox_preds
[
i
]
for
i
in
range
(
num_imgs
)]
cls_reg_targets
=
self
.
get_targets
(
cls_scores_list
,
bbox_preds_list
,
gt_bboxes_list
,
gt_labels_list
,
gt_bboxes_ignore_list
)
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_total_pos
,
num_total_neg
)
=
cls_reg_targets
labels
=
torch
.
cat
(
labels_list
,
0
)
label_weights
=
torch
.
cat
(
label_weights_list
,
0
)
bbox_targets
=
torch
.
cat
(
bbox_targets_list
,
0
)
bbox_weights
=
torch
.
cat
(
bbox_weights_list
,
0
)
# classification loss
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor
=
num_total_pos
*
1.0
+
\
num_total_neg
*
self
.
bg_cls_weight
if
self
.
sync_cls_avg_factor
:
cls_avg_factor
=
reduce_mean
(
cls_scores
.
new_tensor
([
cls_avg_factor
]))
cls_avg_factor
=
max
(
cls_avg_factor
,
1
)
loss_cls
=
self
.
loss_cls
(
cls_scores
,
labels
,
label_weights
,
avg_factor
=
cls_avg_factor
)
# Compute the average number of gt boxes accross all gpus, for
# normalization purposes
num_total_pos
=
loss_cls
.
new_tensor
([
num_total_pos
])
num_total_pos
=
torch
.
clamp
(
reduce_mean
(
num_total_pos
),
min
=
1
).
item
()
# regression L1 loss
bbox_preds
=
bbox_preds
.
reshape
(
-
1
,
bbox_preds
.
size
(
-
1
))
normalized_bbox_targets
=
normalize_bbox
(
bbox_targets
,
self
.
pc_range
)
isnotnan
=
torch
.
isfinite
(
normalized_bbox_targets
).
all
(
dim
=-
1
)
bbox_weights
=
bbox_weights
*
self
.
code_weights
loss_bbox
=
self
.
loss_bbox
(
bbox_preds
[
isnotnan
,
:
10
],
normalized_bbox_targets
[
isnotnan
,
:
10
],
bbox_weights
[
isnotnan
,
:
10
],
avg_factor
=
num_total_pos
)
if
digit_version
(
TORCH_VERSION
)
>=
digit_version
(
'1.8'
):
loss_cls
=
torch
.
nan_to_num
(
loss_cls
)
loss_bbox
=
torch
.
nan_to_num
(
loss_bbox
)
return
loss_cls
,
loss_bbox
@
force_fp32
(
apply_to
=
(
'preds_dicts'
))
def
loss
(
self
,
gt_bboxes_list
,
gt_labels_list
,
preds_dicts
,
gt_bboxes_ignore
=
None
,
img_metas
=
None
):
""""Loss function.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_bbox_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert
gt_bboxes_ignore
is
None
,
\
f
'
{
self
.
__class__
.
__name__
}
only supports '
\
f
'for gt_bboxes_ignore setting to None.'
all_cls_scores
=
preds_dicts
[
'all_cls_scores'
]
all_bbox_preds
=
preds_dicts
[
'all_bbox_preds'
]
enc_cls_scores
=
preds_dicts
[
'enc_cls_scores'
]
enc_bbox_preds
=
preds_dicts
[
'enc_bbox_preds'
]
num_dec_layers
=
len
(
all_cls_scores
)
device
=
gt_labels_list
[
0
].
device
gt_bboxes_list
=
[
torch
.
cat
(
(
gt_bboxes
.
gravity_center
,
gt_bboxes
.
tensor
[:,
3
:]),
dim
=
1
).
to
(
device
)
for
gt_bboxes
in
gt_bboxes_list
]
all_gt_bboxes_list
=
[
gt_bboxes_list
for
_
in
range
(
num_dec_layers
)]
all_gt_labels_list
=
[
gt_labels_list
for
_
in
range
(
num_dec_layers
)]
all_gt_bboxes_ignore_list
=
[
gt_bboxes_ignore
for
_
in
range
(
num_dec_layers
)
]
losses_cls
,
losses_bbox
=
multi_apply
(
self
.
loss_single
,
all_cls_scores
,
all_bbox_preds
,
all_gt_bboxes_list
,
all_gt_labels_list
,
all_gt_bboxes_ignore_list
)
loss_dict
=
dict
()
# loss of proposal generated from encode feature map.
if
enc_cls_scores
is
not
None
:
binary_labels_list
=
[
torch
.
zeros_like
(
gt_labels_list
[
i
])
for
i
in
range
(
len
(
all_gt_labels_list
))
]
enc_loss_cls
,
enc_losses_bbox
=
\
self
.
loss_single
(
enc_cls_scores
,
enc_bbox_preds
,
gt_bboxes_list
,
binary_labels_list
,
gt_bboxes_ignore
)
loss_dict
[
'enc_loss_cls'
]
=
enc_loss_cls
loss_dict
[
'enc_loss_bbox'
]
=
enc_losses_bbox
# loss from the last decoder layer
loss_dict
[
'loss_cls'
]
=
losses_cls
[
-
1
]
loss_dict
[
'loss_bbox'
]
=
losses_bbox
[
-
1
]
# loss from other decoder layers
num_dec_layer
=
0
for
loss_cls_i
,
loss_bbox_i
in
zip
(
losses_cls
[:
-
1
],
losses_bbox
[:
-
1
]):
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_cls'
]
=
loss_cls_i
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_bbox'
]
=
loss_bbox_i
num_dec_layer
+=
1
return
loss_dict
@
force_fp32
(
apply_to
=
(
'preds_dicts'
))
def
get_bboxes
(
self
,
preds_dicts
,
img_metas
,
rescale
=
False
):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
preds_dicts
=
self
.
bbox_coder
.
decode
(
preds_dicts
)
num_samples
=
len
(
preds_dicts
)
ret_list
=
[]
for
i
in
range
(
num_samples
):
preds
=
preds_dicts
[
i
]
bboxes
=
preds
[
'bboxes'
]
bboxes
[:,
2
]
=
bboxes
[:,
2
]
-
bboxes
[:,
5
]
*
0.5
code_size
=
bboxes
.
shape
[
-
1
]
bboxes
=
img_metas
[
i
][
'box_type_3d'
](
bboxes
,
code_size
)
scores
=
preds
[
'scores'
]
labels
=
preds
[
'labels'
]
ret_list
.
append
([
bboxes
,
scores
,
labels
])
return
ret_list
@
HEADS
.
register_module
()
class
BEVFormerHead_GroupDETR
(
BEVFormerHead
):
def
__init__
(
self
,
*
args
,
group_detr
=
1
,
**
kwargs
):
self
.
group_detr
=
group_detr
assert
'num_query'
in
kwargs
kwargs
[
'num_query'
]
=
group_detr
*
kwargs
[
'num_query'
]
super
().
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
mlvl_feats
,
img_metas
,
prev_bev
=
None
,
only_bev
=
False
):
bs
,
num_cam
,
_
,
_
,
_
=
mlvl_feats
[
0
].
shape
dtype
=
mlvl_feats
[
0
].
dtype
object_query_embeds
=
self
.
query_embedding
.
weight
.
to
(
dtype
)
if
not
self
.
training
:
# NOTE: Only difference to bevformer head
object_query_embeds
=
object_query_embeds
[:
self
.
num_query
//
self
.
group_detr
]
bev_queries
=
self
.
bev_embedding
.
weight
.
to
(
dtype
)
bev_mask
=
torch
.
zeros
((
bs
,
self
.
bev_h
,
self
.
bev_w
),
device
=
bev_queries
.
device
).
to
(
dtype
)
bev_pos
=
self
.
positional_encoding
(
bev_mask
).
to
(
dtype
)
if
only_bev
:
return
self
.
transformer
.
get_bev_features
(
mlvl_feats
,
bev_queries
,
self
.
bev_h
,
self
.
bev_w
,
grid_length
=
(
self
.
real_h
/
self
.
bev_h
,
self
.
real_w
/
self
.
bev_w
),
bev_pos
=
bev_pos
,
img_metas
=
img_metas
,
prev_bev
=
prev_bev
,
)
else
:
outputs
=
self
.
transformer
(
mlvl_feats
,
bev_queries
,
object_query_embeds
,
self
.
bev_h
,
self
.
bev_w
,
grid_length
=
(
self
.
real_h
/
self
.
bev_h
,
self
.
real_w
/
self
.
bev_w
),
bev_pos
=
bev_pos
,
reg_branches
=
self
.
reg_branches
if
self
.
with_box_refine
else
None
,
# noqa:E501
cls_branches
=
self
.
cls_branches
if
self
.
as_two_stage
else
None
,
img_metas
=
img_metas
,
prev_bev
=
prev_bev
)
bev_embed
,
hs
,
init_reference
,
inter_references
=
outputs
hs
=
hs
.
permute
(
0
,
2
,
1
,
3
)
outputs_classes
=
[]
outputs_coords
=
[]
for
lvl
in
range
(
hs
.
shape
[
0
]):
if
lvl
==
0
:
reference
=
init_reference
else
:
reference
=
inter_references
[
lvl
-
1
]
reference
=
inverse_sigmoid
(
reference
)
outputs_class
=
self
.
cls_branches
[
lvl
](
hs
[
lvl
])
tmp
=
self
.
reg_branches
[
lvl
](
hs
[
lvl
])
assert
reference
.
shape
[
-
1
]
==
3
tmp
[...,
0
:
2
]
+=
reference
[...,
0
:
2
]
tmp
[...,
0
:
2
]
=
tmp
[...,
0
:
2
].
sigmoid
()
tmp
[...,
4
:
5
]
+=
reference
[...,
2
:
3
]
tmp
[...,
4
:
5
]
=
tmp
[...,
4
:
5
].
sigmoid
()
tmp
[...,
0
:
1
]
=
(
tmp
[...,
0
:
1
]
*
(
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
])
+
self
.
pc_range
[
0
])
tmp
[...,
1
:
2
]
=
(
tmp
[...,
1
:
2
]
*
(
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
])
+
self
.
pc_range
[
1
])
tmp
[...,
4
:
5
]
=
(
tmp
[...,
4
:
5
]
*
(
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
])
+
self
.
pc_range
[
2
])
outputs_coord
=
tmp
outputs_classes
.
append
(
outputs_class
)
outputs_coords
.
append
(
outputs_coord
)
outputs_classes
=
torch
.
stack
(
outputs_classes
)
outputs_coords
=
torch
.
stack
(
outputs_coords
)
outs
=
{
'bev_embed'
:
bev_embed
,
'all_cls_scores'
:
outputs_classes
,
'all_bbox_preds'
:
outputs_coords
,
'enc_cls_scores'
:
None
,
'enc_bbox_preds'
:
None
,
}
return
outs
def
loss
(
self
,
gt_bboxes_list
,
gt_labels_list
,
preds_dicts
,
gt_bboxes_ignore
=
None
,
img_metas
=
None
):
""""Loss function.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_bbox_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert
gt_bboxes_ignore
is
None
,
\
f
'
{
self
.
__class__
.
__name__
}
only supports '
\
f
'for gt_bboxes_ignore setting to None.'
all_cls_scores
=
preds_dicts
[
'all_cls_scores'
]
all_bbox_preds
=
preds_dicts
[
'all_bbox_preds'
]
enc_cls_scores
=
preds_dicts
[
'enc_cls_scores'
]
enc_bbox_preds
=
preds_dicts
[
'enc_bbox_preds'
]
assert
enc_cls_scores
is
None
and
enc_bbox_preds
is
None
num_dec_layers
=
len
(
all_cls_scores
)
device
=
gt_labels_list
[
0
].
device
gt_bboxes_list
=
[
torch
.
cat
(
(
gt_bboxes
.
gravity_center
,
gt_bboxes
.
tensor
[:,
3
:]),
dim
=
1
).
to
(
device
)
for
gt_bboxes
in
gt_bboxes_list
]
all_gt_bboxes_list
=
[
gt_bboxes_list
for
_
in
range
(
num_dec_layers
)]
all_gt_labels_list
=
[
gt_labels_list
for
_
in
range
(
num_dec_layers
)]
all_gt_bboxes_ignore_list
=
[
gt_bboxes_ignore
for
_
in
range
(
num_dec_layers
)
]
loss_dict
=
dict
()
loss_dict
[
'loss_cls'
]
=
0
loss_dict
[
'loss_bbox'
]
=
0
for
num_dec_layer
in
range
(
all_cls_scores
.
shape
[
0
]
-
1
):
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_cls'
]
=
0
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_bbox'
]
=
0
num_query_per_group
=
self
.
num_query
//
self
.
group_detr
for
group_index
in
range
(
self
.
group_detr
):
group_query_start
=
group_index
*
num_query_per_group
group_query_end
=
(
group_index
+
1
)
*
num_query_per_group
group_cls_scores
=
all_cls_scores
[:,
:,
group_query_start
:
group_query_end
,
:]
group_bbox_preds
=
all_bbox_preds
[:,
:,
group_query_start
:
group_query_end
,
:]
losses_cls
,
losses_bbox
=
multi_apply
(
self
.
loss_single
,
group_cls_scores
,
group_bbox_preds
,
all_gt_bboxes_list
,
all_gt_labels_list
,
all_gt_bboxes_ignore_list
)
loss_dict
[
'loss_cls'
]
+=
losses_cls
[
-
1
]
/
self
.
group_detr
loss_dict
[
'loss_bbox'
]
+=
losses_bbox
[
-
1
]
/
self
.
group_detr
# loss from other decoder layers
num_dec_layer
=
0
for
loss_cls_i
,
loss_bbox_i
in
zip
(
losses_cls
[:
-
1
],
losses_bbox
[:
-
1
]):
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_cls'
]
+=
loss_cls_i
/
self
.
group_detr
loss_dict
[
f
'd
{
num_dec_layer
}
.loss_bbox'
]
+=
loss_bbox_i
/
self
.
group_detr
num_dec_layer
+=
1
return
loss_dict
\ No newline at end of file
projects/mmdet3d_plugin/bevformer/detectors/__init__.py
0 → 100644
View file @
4cd43886
from
.bevformer
import
BEVFormer
from
.bevformer_fp16
import
BEVFormer_fp16
from
.bevformerV2
import
BEVFormerV2
\ No newline at end of file
projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
torch
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmdet.models
import
DETECTORS
from
mmdet3d.core
import
bbox3d2result
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
projects.mmdet3d_plugin.models.utils.grid_mask
import
GridMask
import
time
import
copy
import
numpy
as
np
import
mmdet3d
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
@
DETECTORS
.
register_module
()
class
BEVFormer
(
MVXTwoStageDetector
):
"""BEVFormer.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def
__init__
(
self
,
use_grid_mask
=
False
,
pts_voxel_layer
=
None
,
pts_voxel_encoder
=
None
,
pts_middle_encoder
=
None
,
pts_fusion_layer
=
None
,
img_backbone
=
None
,
pts_backbone
=
None
,
img_neck
=
None
,
pts_neck
=
None
,
pts_bbox_head
=
None
,
img_roi_head
=
None
,
img_rpn_head
=
None
,
train_cfg
=
None
,
test_cfg
=
None
,
pretrained
=
None
,
video_test_mode
=
False
):
super
(
BEVFormer
,
self
).
__init__
(
pts_voxel_layer
,
pts_voxel_encoder
,
pts_middle_encoder
,
pts_fusion_layer
,
img_backbone
,
pts_backbone
,
img_neck
,
pts_neck
,
pts_bbox_head
,
img_roi_head
,
img_rpn_head
,
train_cfg
,
test_cfg
,
pretrained
)
self
.
grid_mask
=
GridMask
(
True
,
True
,
rotate
=
1
,
offset
=
False
,
ratio
=
0.5
,
mode
=
1
,
prob
=
0.7
)
self
.
use_grid_mask
=
use_grid_mask
self
.
fp16_enabled
=
False
# temporal
self
.
video_test_mode
=
video_test_mode
self
.
prev_frame_info
=
{
'prev_bev'
:
None
,
'scene_token'
:
None
,
'prev_pos'
:
0
,
'prev_angle'
:
0
,
}
def
extract_img_feat
(
self
,
img
,
img_metas
,
len_queue
=
None
):
"""Extract features of images."""
B
=
img
.
size
(
0
)
if
img
is
not
None
:
# input_shape = img.shape[-2:]
# # update real input shape of each single img
# for img_meta in img_metas:
# img_meta.update(input_shape=input_shape)
if
img
.
dim
()
==
5
and
img
.
size
(
0
)
==
1
:
img
.
squeeze_
()
elif
img
.
dim
()
==
5
and
img
.
size
(
0
)
>
1
:
B
,
N
,
C
,
H
,
W
=
img
.
size
()
img
=
img
.
reshape
(
B
*
N
,
C
,
H
,
W
)
if
self
.
use_grid_mask
:
img
=
self
.
grid_mask
(
img
)
img
=
img
.
to
(
memory_format
=
torch
.
channels_last
)
# img = img.contiguous()
# print("=======================",img.is_contiguous(),"=======================")
img_feats
=
self
.
img_backbone
(
img
)
if
isinstance
(
img_feats
,
dict
):
img_feats
=
list
(
img_feats
.
values
())
else
:
return
None
if
self
.
with_img_neck
:
img_feats
=
self
.
img_neck
(
img_feats
)
img_feats_reshaped
=
[]
for
img_feat
in
img_feats
:
BN
,
C
,
H
,
W
=
img_feat
.
size
()
if
len_queue
is
not
None
:
img_feats_reshaped
.
append
(
img_feat
.
view
(
int
(
B
/
len_queue
),
len_queue
,
int
(
BN
/
B
),
C
,
H
,
W
))
else
:
img_feats_reshaped
.
append
(
img_feat
.
view
(
B
,
int
(
BN
/
B
),
C
,
H
,
W
))
return
img_feats_reshaped
@
auto_fp16
(
apply_to
=
(
'img'
))
def
extract_feat
(
self
,
img
,
img_metas
=
None
,
len_queue
=
None
):
"""Extract features from images and points."""
img_feats
=
self
.
extract_img_feat
(
img
,
img_metas
,
len_queue
=
len_queue
)
return
img_feats
def
forward_pts_train
(
self
,
pts_feats
,
gt_bboxes_3d
,
gt_labels_3d
,
img_metas
,
gt_bboxes_ignore
=
None
,
prev_bev
=
None
):
"""Forward function'
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sampole
img_metas (list[dict]): Meta information of samples.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
prev_bev (torch.Tensor, optional): BEV features of previous frame.
Returns:
dict: Losses of each branch.
"""
outs
=
self
.
pts_bbox_head
(
pts_feats
,
img_metas
,
prev_bev
)
loss_inputs
=
[
gt_bboxes_3d
,
gt_labels_3d
,
outs
]
losses
=
self
.
pts_bbox_head
.
loss
(
*
loss_inputs
,
img_metas
=
img_metas
)
return
losses
def
forward_dummy
(
self
,
img
):
dummy_metas
=
None
return
self
.
forward_test
(
img
=
img
,
img_metas
=
[[
dummy_metas
]])
def
forward
(
self
,
return_loss
=
True
,
**
kwargs
):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if
return_loss
:
return
self
.
forward_train
(
**
kwargs
)
else
:
return
self
.
forward_test
(
**
kwargs
)
def
obtain_history_bev
(
self
,
imgs_queue
,
img_metas_list
):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
self
.
eval
()
with
torch
.
no_grad
():
prev_bev
=
None
bs
,
len_queue
,
num_cams
,
C
,
H
,
W
=
imgs_queue
.
shape
imgs_queue
=
imgs_queue
.
reshape
(
bs
*
len_queue
,
num_cams
,
C
,
H
,
W
)
img_feats_list
=
self
.
extract_feat
(
img
=
imgs_queue
,
len_queue
=
len_queue
)
for
i
in
range
(
len_queue
):
img_metas
=
[
each
[
i
]
for
each
in
img_metas_list
]
if
not
img_metas
[
0
][
'prev_bev_exists'
]:
prev_bev
=
None
# img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats
=
[
each_scale
[:,
i
]
for
each_scale
in
img_feats_list
]
prev_bev
=
self
.
pts_bbox_head
(
img_feats
,
img_metas
,
prev_bev
,
only_bev
=
True
)
self
.
train
()
return
prev_bev
@
auto_fp16
(
apply_to
=
(
'img'
,
'points'
))
def
forward_train
(
self
,
points
=
None
,
img_metas
=
None
,
gt_bboxes_3d
=
None
,
gt_labels_3d
=
None
,
gt_labels
=
None
,
gt_bboxes
=
None
,
img
=
None
,
proposals
=
None
,
gt_bboxes_ignore
=
None
,
img_depth
=
None
,
img_mask
=
None
,
):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
len_queue
=
img
.
size
(
1
)
prev_img
=
img
[:,
:
-
1
,
...]
img
=
img
[:,
-
1
,
...]
prev_img_metas
=
copy
.
deepcopy
(
img_metas
)
prev_bev
=
self
.
obtain_history_bev
(
prev_img
,
prev_img_metas
)
img_metas
=
[
each
[
len_queue
-
1
]
for
each
in
img_metas
]
if
not
img_metas
[
0
][
'prev_bev_exists'
]:
prev_bev
=
None
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
losses
=
dict
()
losses_pts
=
self
.
forward_pts_train
(
img_feats
,
gt_bboxes_3d
,
gt_labels_3d
,
img_metas
,
gt_bboxes_ignore
,
prev_bev
)
losses
.
update
(
losses_pts
)
return
losses
def
forward_test
(
self
,
img_metas
,
img
=
None
,
**
kwargs
):
for
var
,
name
in
[(
img_metas
,
'img_metas'
)]:
if
not
isinstance
(
var
,
list
):
raise
TypeError
(
'{} must be a list, but got {}'
.
format
(
name
,
type
(
var
)))
img
=
[
img
]
if
img
is
None
else
img
if
img_metas
[
0
][
0
][
'scene_token'
]
!=
self
.
prev_frame_info
[
'scene_token'
]:
# the first sample of each scene is truncated
self
.
prev_frame_info
[
'prev_bev'
]
=
None
# update idx
self
.
prev_frame_info
[
'scene_token'
]
=
img_metas
[
0
][
0
][
'scene_token'
]
# do not use temporal information
if
not
self
.
video_test_mode
:
self
.
prev_frame_info
[
'prev_bev'
]
=
None
# Get the delta of ego position and angle between two timestamps.
tmp_pos
=
copy
.
deepcopy
(
img_metas
[
0
][
0
][
'can_bus'
][:
3
])
tmp_angle
=
copy
.
deepcopy
(
img_metas
[
0
][
0
][
'can_bus'
][
-
1
])
if
self
.
prev_frame_info
[
'prev_bev'
]
is
not
None
:
img_metas
[
0
][
0
][
'can_bus'
][:
3
]
-=
self
.
prev_frame_info
[
'prev_pos'
]
img_metas
[
0
][
0
][
'can_bus'
][
-
1
]
-=
self
.
prev_frame_info
[
'prev_angle'
]
else
:
img_metas
[
0
][
0
][
'can_bus'
][
-
1
]
=
0
img_metas
[
0
][
0
][
'can_bus'
][:
3
]
=
0
new_prev_bev
,
bbox_results
=
self
.
simple_test
(
img_metas
[
0
],
img
[
0
],
prev_bev
=
self
.
prev_frame_info
[
'prev_bev'
],
**
kwargs
)
# During inference, we save the BEV features and ego motion of each timestamp.
self
.
prev_frame_info
[
'prev_pos'
]
=
tmp_pos
self
.
prev_frame_info
[
'prev_angle'
]
=
tmp_angle
self
.
prev_frame_info
[
'prev_bev'
]
=
new_prev_bev
return
bbox_results
def
simple_test_pts
(
self
,
x
,
img_metas
,
prev_bev
=
None
,
rescale
=
False
):
"""Test function"""
outs
=
self
.
pts_bbox_head
(
x
,
img_metas
,
prev_bev
=
prev_bev
)
bbox_list
=
self
.
pts_bbox_head
.
get_bboxes
(
outs
,
img_metas
,
rescale
=
rescale
)
bbox_results
=
[
bbox3d2result
(
bboxes
,
scores
,
labels
)
for
bboxes
,
scores
,
labels
in
bbox_list
]
return
outs
[
'bev_embed'
],
bbox_results
def
simple_test
(
self
,
img_metas
,
img
=
None
,
prev_bev
=
None
,
rescale
=
False
):
"""Test function without augmentaiton."""
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
bbox_list
=
[
dict
()
for
i
in
range
(
len
(
img_metas
))]
new_prev_bev
,
bbox_pts
=
self
.
simple_test_pts
(
img_feats
,
img_metas
,
prev_bev
,
rescale
=
rescale
)
for
result_dict
,
pts_bbox
in
zip
(
bbox_list
,
bbox_pts
):
result_dict
[
'pts_bbox'
]
=
pts_bbox
return
new_prev_bev
,
bbox_list
projects/mmdet3d_plugin/bevformer/detectors/bevformerV2.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
copy
from
collections
import
OrderedDict
import
torch
from
mmdet.models
import
DETECTORS
from
mmdet3d.core
import
bbox3d2result
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
mmdet3d.models.builder
import
build_head
from
projects.mmdet3d_plugin.models.utils.grid_mask
import
GridMask
@
DETECTORS
.
register_module
()
class
BEVFormerV2
(
MVXTwoStageDetector
):
"""BEVFormer.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def
__init__
(
self
,
use_grid_mask
=
False
,
pts_voxel_layer
=
None
,
pts_voxel_encoder
=
None
,
pts_middle_encoder
=
None
,
pts_fusion_layer
=
None
,
img_backbone
=
None
,
pts_backbone
=
None
,
img_neck
=
None
,
pts_neck
=
None
,
pts_bbox_head
=
None
,
fcos3d_bbox_head
=
None
,
img_roi_head
=
None
,
img_rpn_head
=
None
,
train_cfg
=
None
,
test_cfg
=
None
,
pretrained
=
None
,
video_test_mode
=
False
,
num_levels
=
None
,
num_mono_levels
=
None
,
mono_loss_weight
=
1.0
,
frames
=
(
0
,),
):
super
(
BEVFormerV2
,
self
).
__init__
(
pts_voxel_layer
,
pts_voxel_encoder
,
pts_middle_encoder
,
pts_fusion_layer
,
img_backbone
,
pts_backbone
,
img_neck
,
pts_neck
,
pts_bbox_head
,
img_roi_head
,
img_rpn_head
,
train_cfg
,
test_cfg
,
pretrained
)
self
.
grid_mask
=
GridMask
(
True
,
True
,
rotate
=
1
,
offset
=
False
,
ratio
=
0.5
,
mode
=
1
,
prob
=
0.7
)
self
.
use_grid_mask
=
use_grid_mask
self
.
fp16_enabled
=
False
assert
not
self
.
fp16_enabled
# not support fp16 yet
# temporal
self
.
video_test_mode
=
video_test_mode
assert
not
self
.
video_test_mode
# not support video_test_mode yet
# fcos3d head
self
.
fcos3d_bbox_head
=
build_head
(
fcos3d_bbox_head
)
if
fcos3d_bbox_head
else
None
# loss weight
self
.
mono_loss_weight
=
mono_loss_weight
# levels of features
self
.
num_levels
=
num_levels
self
.
num_mono_levels
=
num_mono_levels
self
.
frames
=
frames
def
extract_img_feat
(
self
,
img
):
"""Extract features of images."""
B
=
img
.
size
(
0
)
if
img
is
not
None
:
if
img
.
dim
()
==
5
and
img
.
size
(
0
)
==
1
:
img
.
squeeze_
()
elif
img
.
dim
()
==
5
and
img
.
size
(
0
)
>
1
:
B
,
N
,
C
,
H
,
W
=
img
.
size
()
img
=
img
.
reshape
(
B
*
N
,
C
,
H
,
W
)
if
self
.
use_grid_mask
:
img
=
self
.
grid_mask
(
img
)
img_feats
=
self
.
img_backbone
(
img
)
if
isinstance
(
img_feats
,
dict
):
img_feats
=
list
(
img_feats
.
values
())
else
:
return
None
if
self
.
with_img_neck
:
img_feats
=
self
.
img_neck
(
img_feats
)
img_feats_reshaped
=
[]
for
img_feat
in
img_feats
:
BN
,
C
,
H
,
W
=
img_feat
.
size
()
img_feats_reshaped
.
append
(
img_feat
.
view
(
B
,
int
(
BN
/
B
),
C
,
H
,
W
))
return
img_feats_reshaped
def
extract_feat
(
self
,
img
,
img_metas
,
len_queue
=
None
):
"""Extract features from images and points."""
img_feats
=
self
.
extract_img_feat
(
img
)
if
'aug_param'
in
img_metas
[
0
]
and
img_metas
[
0
][
'aug_param'
][
'CropResizeFlipImage_param'
][
-
1
]
is
True
:
# flip feature
img_feats
=
[
torch
.
flip
(
x
,
dims
=
[
-
1
,
])
for
x
in
img_feats
]
return
img_feats
def
forward_pts_train
(
self
,
pts_feats
,
gt_bboxes_3d
,
gt_labels_3d
,
img_metas
,
gt_bboxes_ignore
=
None
,
prev_bev
=
None
):
outs
=
self
.
pts_bbox_head
(
pts_feats
,
img_metas
,
prev_bev
)
loss_inputs
=
[
gt_bboxes_3d
,
gt_labels_3d
,
outs
]
losses
=
self
.
pts_bbox_head
.
loss
(
*
loss_inputs
,
img_metas
=
img_metas
)
return
losses
def
forward_mono_train
(
self
,
img_feats
,
mono_input_dict
):
"""
img_feats (list[Tensor]): 5-D tensor for each level, (B, N, C, H, W)
gt_bboxes (list[list[Tensor]]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[list[Tensor]]): class indices corresponding to each box
gt_bboxes_3d (list[list[[Tensor]]): 3D boxes ground truth with shape of
(num_gts, code_size).
gt_labels_3d (list[list[Tensor]]): same as gt_labels
centers2d (list[list[Tensor]]): 2D centers on the image with shape of
(num_gts, 2).
depths (list[list[Tensor]]): Depth ground truth with shape of
(num_gts, ).
attr_labels (list[list[Tensor]]): Attributes indices of each box.
img_metas (list[list[dict]]): Meta information of each image, e.g.,
image size, scaling factor, etc.
ann_idx (list[list[idx]]): indicate which image has mono annotation.
"""
bsz
=
img_feats
[
0
].
shape
[
0
];
num_lvls
=
len
(
img_feats
)
img_feats_select
=
[[]
for
lvl
in
range
(
num_lvls
)]
for
lvl
,
img_feat
in
enumerate
(
img_feats
):
for
i
in
range
(
bsz
):
img_feats_select
[
lvl
].
append
(
img_feat
[
i
,
mono_input_dict
[
'mono_ann_idx'
][
i
]])
img_feats_select
[
lvl
]
=
torch
.
cat
(
img_feats_select
[
lvl
],
dim
=
0
)
bsz_new
=
img_feats_select
[
0
].
shape
[
0
]
assert
bsz
==
len
(
mono_input_dict
[
'mono_input_dict'
])
input_dict
=
[]
for
i
in
range
(
bsz
):
input_dict
.
extend
(
mono_input_dict
[
'mono_input_dict'
][
i
])
assert
bsz_new
==
len
(
input_dict
)
losses
=
self
.
fcos3d_bbox_head
.
forward_train
(
img_feats_select
,
input_dict
)
return
losses
def
forward_dummy
(
self
,
img
):
dummy_metas
=
None
return
self
.
forward_test
(
img
=
img
,
img_metas
=
[[
dummy_metas
]])
def
forward
(
self
,
return_loss
=
True
,
**
kwargs
):
if
return_loss
:
return
self
.
forward_train
(
**
kwargs
)
else
:
return
self
.
forward_test
(
**
kwargs
)
def
obtain_history_bev
(
self
,
img_dict
,
img_metas_dict
):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
# Modify: roll back to previous version for single frame
is_training
=
self
.
training
self
.
eval
()
prev_bev
=
OrderedDict
({
i
:
None
for
i
in
self
.
frames
})
with
torch
.
no_grad
():
for
t
in
img_dict
.
keys
():
img
=
img_dict
[
t
]
img_metas
=
[
img_metas_dict
[
t
],
]
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
if
self
.
num_levels
:
img_feats
=
img_feats
[:
self
.
num_levels
]
bev
=
self
.
pts_bbox_head
(
img_feats
,
img_metas
,
None
,
only_bev
=
True
)
prev_bev
[
t
]
=
bev
.
detach
()
if
is_training
:
self
.
train
()
return
list
(
prev_bev
.
values
())
def
forward_train
(
self
,
points
=
None
,
img_metas
=
None
,
gt_bboxes_3d
=
None
,
gt_labels_3d
=
None
,
img
=
None
,
gt_bboxes_ignore
=
None
,
**
mono_input_dict
,
):
img_metas
=
OrderedDict
(
sorted
(
img_metas
[
0
].
items
()))
img_dict
=
{}
for
ind
,
t
in
enumerate
(
img_metas
.
keys
()):
img_dict
[
t
]
=
img
[:,
ind
,
...]
img
=
img_dict
[
0
]
img_dict
.
pop
(
0
)
prev_img_metas
=
copy
.
deepcopy
(
img_metas
)
prev_img_metas
.
pop
(
0
)
prev_bev
=
self
.
obtain_history_bev
(
img_dict
,
prev_img_metas
)
img_metas
=
[
img_metas
[
0
],
]
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
losses
=
dict
()
losses_pts
=
self
.
forward_pts_train
(
img_feats
if
self
.
num_levels
is
None
else
img_feats
[:
self
.
num_levels
],
gt_bboxes_3d
,
gt_labels_3d
,
img_metas
,
gt_bboxes_ignore
,
prev_bev
)
losses
.
update
(
losses_pts
)
if
self
.
fcos3d_bbox_head
:
losses_mono
=
self
.
forward_mono_train
(
img_feats
=
img_feats
if
self
.
num_mono_levels
is
None
else
img_feats
[:
self
.
num_mono_levels
],
mono_input_dict
=
mono_input_dict
)
for
k
,
v
in
losses_mono
.
items
():
losses
[
f
'
{
k
}
_mono'
]
=
v
*
self
.
mono_loss_weight
return
losses
def
forward_test
(
self
,
img_metas
,
img
=
None
,
**
kwargs
):
for
var
,
name
in
[(
img_metas
,
'img_metas'
)]:
if
not
isinstance
(
var
,
list
):
raise
TypeError
(
'{} must be a list, but got {}'
.
format
(
name
,
type
(
var
)))
img
=
[
img
]
if
img
is
None
else
img
new_prev_bev
,
bbox_results
=
self
.
simple_test
(
img_metas
[
0
],
img
[
0
],
prev_bev
=
None
,
**
kwargs
)
return
bbox_results
def
simple_test_pts
(
self
,
x
,
img_metas
,
prev_bev
=
None
,
rescale
=
False
):
"""Test function"""
outs
=
self
.
pts_bbox_head
(
x
,
img_metas
,
prev_bev
=
prev_bev
)
bbox_list
=
self
.
pts_bbox_head
.
get_bboxes
(
outs
,
img_metas
,
rescale
=
rescale
)
bbox_results
=
[
bbox3d2result
(
bboxes
,
scores
,
labels
)
for
bboxes
,
scores
,
labels
in
bbox_list
]
return
outs
[
'bev_embed'
],
bbox_results
def
simple_test
(
self
,
img_metas
,
img
=
None
,
prev_bev
=
None
,
rescale
=
False
,
**
kwargs
):
"""Test function without augmentaiton."""
img_metas
=
OrderedDict
(
sorted
(
img_metas
[
0
].
items
()))
img_dict
=
{}
for
ind
,
t
in
enumerate
(
img_metas
.
keys
()):
img_dict
[
t
]
=
img
[:,
ind
,
...]
img
=
img_dict
[
0
]
img_dict
.
pop
(
0
)
prev_img_metas
=
copy
.
deepcopy
(
img_metas
)
prev_bev
=
self
.
obtain_history_bev
(
img_dict
,
prev_img_metas
)
img_metas
=
[
img_metas
[
0
],
]
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
if
self
.
num_levels
:
img_feats
=
img_feats
[:
self
.
num_levels
]
bbox_list
=
[
dict
()
for
i
in
range
(
len
(
img_metas
))]
new_prev_bev
,
bbox_pts
=
self
.
simple_test_pts
(
img_feats
,
img_metas
,
prev_bev
,
rescale
=
rescale
)
for
result_dict
,
pts_bbox
in
zip
(
bbox_list
,
bbox_pts
):
result_dict
[
'pts_bbox'
]
=
pts_bbox
return
new_prev_bev
,
bbox_list
projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
tkinter.messagebox
import
NO
import
torch
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmdet.models
import
DETECTORS
from
mmdet3d.core
import
bbox3d2result
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
projects.mmdet3d_plugin.models.utils.grid_mask
import
GridMask
from
projects.mmdet3d_plugin.bevformer.detectors.bevformer
import
BEVFormer
import
time
import
copy
import
numpy
as
np
import
mmdet3d
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
@
DETECTORS
.
register_module
()
class
BEVFormer_fp16
(
BEVFormer
):
"""
The default version BEVFormer currently can not support FP16.
We provide this version to resolve this issue.
"""
@
auto_fp16
(
apply_to
=
(
'img'
,
'prev_bev'
,
'points'
))
def
forward_train
(
self
,
points
=
None
,
img_metas
=
None
,
gt_bboxes_3d
=
None
,
gt_labels_3d
=
None
,
gt_labels
=
None
,
gt_bboxes
=
None
,
img
=
None
,
proposals
=
None
,
gt_bboxes_ignore
=
None
,
img_depth
=
None
,
img_mask
=
None
,
prev_bev
=
None
,
):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
losses
=
dict
()
losses_pts
=
self
.
forward_pts_train
(
img_feats
,
gt_bboxes_3d
,
gt_labels_3d
,
img_metas
,
gt_bboxes_ignore
,
prev_bev
=
prev_bev
)
losses
.
update
(
losses_pts
)
return
losses
def
val_step
(
self
,
data
,
optimizer
):
"""
In BEVFormer_fp16, we use this `val_step` function to inference the `prev_pev`.
This is not the standard function of `val_step`.
"""
img
=
data
[
'img'
]
img_metas
=
data
[
'img_metas'
]
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
img_metas
)
prev_bev
=
data
.
get
(
'prev_bev'
,
None
)
prev_bev
=
self
.
pts_bbox_head
(
img_feats
,
img_metas
,
prev_bev
=
prev_bev
,
only_bev
=
True
)
return
prev_bev
\ No newline at end of file
projects/mmdet3d_plugin/bevformer/hooks/__init__.py
0 → 100644
View file @
4cd43886
from
.custom_hooks
import
TransferWeight
\ No newline at end of file
projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
0 → 100644
View file @
4cd43886
from
mmcv.runner.hooks.hook
import
HOOKS
,
Hook
from
projects.mmdet3d_plugin.models.utils
import
run_time
@
HOOKS
.
register_module
()
class
TransferWeight
(
Hook
):
def
__init__
(
self
,
every_n_inters
=
1
):
self
.
every_n_inters
=
every_n_inters
def
after_train_iter
(
self
,
runner
):
if
self
.
every_n_inner_iters
(
runner
,
self
.
every_n_inters
):
runner
.
eval_model
.
load_state_dict
(
runner
.
model
.
state_dict
())
projects/mmdet3d_plugin/bevformer/modules/__init__.py
0 → 100644
View file @
4cd43886
from
.transformer
import
PerceptionTransformer
from
.transformerV2
import
PerceptionTransformerV2
,
PerceptionTransformerBEVEncoder
from
.spatial_cross_attention
import
SpatialCrossAttention
,
MSDeformableAttention3D
from
.temporal_self_attention
import
TemporalSelfAttention
from
.encoder
import
BEVFormerEncoder
,
BEVFormerLayer
from
.decoder
import
DetectionTransformerDecoder
from
.group_attention
import
GroupMultiheadAttention
projects/mmdet3d_plugin/bevformer/modules/custom_base_transformer_layer.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
copy
import
warnings
import
torch
import
torch.nn
as
nn
from
mmcv
import
ConfigDict
,
deprecated_api_warning
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
build_norm_layer
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
FEEDFORWARD_NETWORK
,
POSITIONAL_ENCODING
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try
:
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
# noqa F401
warnings
.
warn
(
ImportWarning
(
'``MultiScaleDeformableAttention`` has been moved to '
'``mmcv.ops.multi_scale_deform_attn``, please change original path '
# noqa E501
'``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '
# noqa E501
'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '
# noqa E501
))
except
ImportError
:
warnings
.
warn
(
'Fail to import ``MultiScaleDeformableAttention`` from '
'``mmcv.ops.multi_scale_deform_attn``, '
'You should install ``mmcv-full`` if you need this module. '
)
from
mmcv.cnn.bricks.transformer
import
build_feedforward_network
,
build_attention
@
TRANSFORMER_LAYER
.
register_module
()
class
MyCustomBaseTransformerLayer
(
BaseModule
):
"""Base `TransformerLayer` for vision transformer.
It can be built from `mmcv.ConfigDict` and support more flexible
customization, for example, using any number of `FFN or LN ` and
use different kinds of `attention` by specifying a list of `ConfigDict`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
when you specifying `norm` as the first element of `operation_order`.
More details about the `prenorm`: `On Layer Normalization in the
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for `self_attention` or `cross_attention` modules,
The order of the configs in the list should be consistent with
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for FFN, The order of the configs in the list should be
consistent with corresponding ffn in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Support `prenorm` when you specifying first element as `norm`.
Default:None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Key, Query and Value are shape
of (batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
"""
def
__init__
(
self
,
attn_cfgs
=
None
,
ffn_cfgs
=
dict
(
type
=
'FFN'
,
embed_dims
=
256
,
feedforward_channels
=
1024
,
num_fcs
=
2
,
ffn_drop
=
0.
,
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
),
operation_order
=
None
,
norm_cfg
=
dict
(
type
=
'LN'
),
init_cfg
=
None
,
batch_first
=
True
,
**
kwargs
):
deprecated_args
=
dict
(
feedforward_channels
=
'feedforward_channels'
,
ffn_dropout
=
'ffn_drop'
,
ffn_num_fcs
=
'num_fcs'
)
for
ori_name
,
new_name
in
deprecated_args
.
items
():
if
ori_name
in
kwargs
:
warnings
.
warn
(
f
'The arguments `
{
ori_name
}
` in BaseTransformerLayer '
f
'has been deprecated, now you should set `
{
new_name
}
` '
f
'and other FFN related arguments '
f
'to a dict named `ffn_cfgs`. '
)
ffn_cfgs
[
new_name
]
=
kwargs
[
ori_name
]
super
(
MyCustomBaseTransformerLayer
,
self
).
__init__
(
init_cfg
)
self
.
batch_first
=
batch_first
assert
set
(
operation_order
)
&
set
(
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
])
==
\
set
(
operation_order
),
f
'The operation_order of'
\
f
'
{
self
.
__class__
.
__name__
}
should '
\
f
'contains all four operation type '
\
f
"
{
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
]
}
"
num_attn
=
operation_order
.
count
(
'self_attn'
)
+
operation_order
.
count
(
'cross_attn'
)
if
isinstance
(
attn_cfgs
,
dict
):
attn_cfgs
=
[
copy
.
deepcopy
(
attn_cfgs
)
for
_
in
range
(
num_attn
)]
else
:
assert
num_attn
==
len
(
attn_cfgs
),
f
'The length '
\
f
'of attn_cfg
{
num_attn
}
is '
\
f
'not consistent with the number of attention'
\
f
'in operation_order
{
operation_order
}
.'
self
.
num_attn
=
num_attn
self
.
operation_order
=
operation_order
self
.
norm_cfg
=
norm_cfg
self
.
pre_norm
=
operation_order
[
0
]
==
'norm'
self
.
attentions
=
ModuleList
()
index
=
0
for
operation_name
in
operation_order
:
if
operation_name
in
[
'self_attn'
,
'cross_attn'
]:
if
'batch_first'
in
attn_cfgs
[
index
]:
assert
self
.
batch_first
==
attn_cfgs
[
index
][
'batch_first'
]
else
:
attn_cfgs
[
index
][
'batch_first'
]
=
self
.
batch_first
attention
=
build_attention
(
attn_cfgs
[
index
])
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention
.
operation_name
=
operation_name
self
.
attentions
.
append
(
attention
)
index
+=
1
self
.
embed_dims
=
self
.
attentions
[
0
].
embed_dims
self
.
ffns
=
ModuleList
()
num_ffns
=
operation_order
.
count
(
'ffn'
)
if
isinstance
(
ffn_cfgs
,
dict
):
ffn_cfgs
=
ConfigDict
(
ffn_cfgs
)
if
isinstance
(
ffn_cfgs
,
dict
):
ffn_cfgs
=
[
copy
.
deepcopy
(
ffn_cfgs
)
for
_
in
range
(
num_ffns
)]
assert
len
(
ffn_cfgs
)
==
num_ffns
for
ffn_index
in
range
(
num_ffns
):
if
'embed_dims'
not
in
ffn_cfgs
[
ffn_index
]:
ffn_cfgs
[
'embed_dims'
]
=
self
.
embed_dims
else
:
assert
ffn_cfgs
[
ffn_index
][
'embed_dims'
]
==
self
.
embed_dims
self
.
ffns
.
append
(
build_feedforward_network
(
ffn_cfgs
[
ffn_index
]))
self
.
norms
=
ModuleList
()
num_norms
=
operation_order
.
count
(
'norm'
)
for
_
in
range
(
num_norms
):
self
.
norms
.
append
(
build_norm_layer
(
norm_cfg
,
self
.
embed_dims
)[
1
])
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
query_pos
=
None
,
key_pos
=
None
,
attn_masks
=
None
,
query_key_padding_mask
=
None
,
key_padding_mask
=
None
,
**
kwargs
):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index
=
0
attn_index
=
0
ffn_index
=
0
identity
=
query
if
attn_masks
is
None
:
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
attn_masks
=
[
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
]
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
f
'
{
self
.
__class__
.
__name__
}
'
)
else
:
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'to the number of attention in '
\
f
'operation_order
{
self
.
num_attn
}
'
for
layer
in
self
.
operation_order
:
if
layer
==
'self_attn'
:
temp_key
=
temp_value
=
query
query
=
self
.
attentions
[
attn_index
](
query
,
temp_key
,
temp_value
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
key_pos
=
query_pos
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
query_key_padding_mask
,
**
kwargs
)
attn_index
+=
1
identity
=
query
elif
layer
==
'norm'
:
query
=
self
.
norms
[
norm_index
](
query
)
norm_index
+=
1
elif
layer
==
'cross_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
,
key
,
value
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
key_pos
=
key_pos
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
key_padding_mask
,
**
kwargs
)
attn_index
+=
1
identity
=
query
elif
layer
==
'ffn'
:
query
=
self
.
ffns
[
ffn_index
](
query
,
identity
if
self
.
pre_norm
else
None
)
ffn_index
+=
1
return
query
projects/mmdet3d_plugin/bevformer/modules/decoder.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
mmcv.ops.multi_scale_deform_attn
import
multi_scale_deformable_attn_pytorch
import
mmcv
import
cv2
as
cv
import
copy
import
warnings
from
matplotlib
import
pyplot
as
plt
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
import
math
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.utils
import
(
ConfigDict
,
build_from_cfg
,
deprecated_api_warning
,
to_2tuple
)
from
mmcv.utils
import
ext_loader
from
.multi_scale_deformable_attn_function
import
MultiScaleDeformableAttnFunction_fp32
,
\
MultiScaleDeformableAttnFunction_fp16
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
def
inverse_sigmoid
(
x
,
eps
=
1e-5
):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
overflow. Defaults 1e-5.
Returns:
Tensor: The x has passed the inverse
function of sigmoid, has same
shape with input.
"""
x
=
x
.
clamp
(
min
=
0
,
max
=
1
)
x1
=
x
.
clamp
(
min
=
eps
)
x2
=
(
1
-
x
).
clamp
(
min
=
eps
)
return
torch
.
log
(
x1
/
x2
)
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
DetectionTransformerDecoder
(
TransformerLayerSequence
):
"""Implements the decoder in DETR3D transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def
__init__
(
self
,
*
args
,
return_intermediate
=
False
,
**
kwargs
):
super
(
DetectionTransformerDecoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
return_intermediate
=
return_intermediate
self
.
fp16_enabled
=
False
def
forward
(
self
,
query
,
*
args
,
reference_points
=
None
,
reg_branches
=
None
,
key_padding_mask
=
None
,
**
kwargs
):
"""Forward function for `Detr3DTransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
reg_branch: (obj:`nn.ModuleList`): Used for
refining the regression results. Only would
be passed when with_box_refine is True,
otherwise would be passed a `None`.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output
=
query
intermediate
=
[]
intermediate_reference_points
=
[]
for
lid
,
layer
in
enumerate
(
self
.
layers
):
reference_points_input
=
reference_points
[...,
:
2
].
unsqueeze
(
2
)
# BS NUM_QUERY NUM_LEVEL 2
output
=
layer
(
output
,
*
args
,
reference_points
=
reference_points_input
,
key_padding_mask
=
key_padding_mask
,
**
kwargs
)
output
=
output
.
permute
(
1
,
0
,
2
)
if
reg_branches
is
not
None
:
tmp
=
reg_branches
[
lid
](
output
)
assert
reference_points
.
shape
[
-
1
]
==
3
new_reference_points
=
torch
.
zeros_like
(
reference_points
)
new_reference_points
[...,
:
2
]
=
tmp
[
...,
:
2
]
+
inverse_sigmoid
(
reference_points
[...,
:
2
])
new_reference_points
[...,
2
:
3
]
=
tmp
[
...,
4
:
5
]
+
inverse_sigmoid
(
reference_points
[...,
2
:
3
])
new_reference_points
=
new_reference_points
.
sigmoid
()
reference_points
=
new_reference_points
.
detach
()
output
=
output
.
permute
(
1
,
0
,
2
)
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
intermediate_reference_points
.
append
(
reference_points
)
if
self
.
return_intermediate
:
return
torch
.
stack
(
intermediate
),
torch
.
stack
(
intermediate_reference_points
)
return
output
,
reference_points
@
ATTENTION
.
register_module
()
class
CustomMSDeformableAttention
(
BaseModule
):
"""An attention module used in Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def
__init__
(
self
,
embed_dims
=
256
,
num_heads
=
8
,
num_levels
=
4
,
num_points
=
4
,
im2col_step
=
64
,
dropout
=
0.1
,
batch_first
=
False
,
norm_cfg
=
None
,
init_cfg
=
None
):
super
().
__init__
(
init_cfg
)
if
embed_dims
%
num_heads
!=
0
:
raise
ValueError
(
f
'embed_dims must be divisible by num_heads, '
f
'but got
{
embed_dims
}
and
{
num_heads
}
'
)
dim_per_head
=
embed_dims
//
num_heads
self
.
norm_cfg
=
norm_cfg
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
batch_first
=
batch_first
self
.
fp16_enabled
=
False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
'invalid input for _is_power_of_2: {} (type: {})'
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
if
not
_is_power_of_2
(
dim_per_head
):
warnings
.
warn
(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.'
)
self
.
im2col_step
=
im2col_step
self
.
embed_dims
=
embed_dims
self
.
num_levels
=
num_levels
self
.
num_heads
=
num_heads
self
.
num_points
=
num_points
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
output_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
init_weights
()
def
init_weights
(
self
):
"""Default initialization for Parameters of Module."""
constant_init
(
self
.
sampling_offsets
,
0.
)
thetas
=
torch
.
arange
(
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
]).
view
(
self
.
num_heads
,
1
,
1
,
2
).
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
self
.
sampling_offsets
.
bias
.
data
=
grid_init
.
view
(
-
1
)
constant_init
(
self
.
attention_weights
,
val
=
0.
,
bias
=
0.
)
xavier_init
(
self
.
value_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
self
.
_is_init
=
True
@
deprecated_api_warning
({
'residual'
:
'identity'
},
cls_name
=
'MultiScaleDeformableAttention'
)
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
identity
=
None
,
query_pos
=
None
,
key_padding_mask
=
None
,
reference_points
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
flag
=
'decoder'
,
**
kwargs
):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if
value
is
None
:
value
=
query
if
identity
is
None
:
identity
=
query
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
0.0
)
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
)
if
reference_points
.
shape
[
-
1
]
==
2
:
offset_normalizer
=
torch
.
stack
(
[
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:]
\
+
sampling_offsets
\
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
elif
reference_points
.
shape
[
-
1
]
==
4
:
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
\
+
sampling_offsets
/
self
.
num_points
\
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
\
*
0.5
else
:
raise
ValueError
(
f
'Last dim of reference_points must be'
f
' 2 or 4, but get
{
reference_points
.
shape
[
-
1
]
}
instead.'
)
if
torch
.
cuda
.
is_available
()
and
value
.
is_cuda
:
# using fp16 deformable attention is unstable because it performs many sum operations
if
value
.
dtype
==
torch
.
float16
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
else
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
)
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
output
=
self
.
output_proj
(
output
)
if
not
self
.
batch_first
:
# (num_query, bs ,embed_dims)
output
=
output
.
permute
(
1
,
0
,
2
)
return
self
.
dropout
(
output
)
+
identity
projects/mmdet3d_plugin/bevformer/modules/encoder.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
numpy
as
np
import
torch
import
copy
import
warnings
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
ext_loader
from
.custom_base_transformer_layer
import
MyCustomBaseTransformerLayer
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
BEVFormerEncoder
(
TransformerLayerSequence
):
"""
Attention with both self and cross
Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def
__init__
(
self
,
*
args
,
pc_range
=
None
,
num_points_in_pillar
=
4
,
return_intermediate
=
False
,
dataset_type
=
'nuscenes'
,
**
kwargs
):
super
(
BEVFormerEncoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
return_intermediate
=
return_intermediate
self
.
num_points_in_pillar
=
num_points_in_pillar
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
@
staticmethod
def
get_reference_points
(
H
,
W
,
Z
=
8
,
num_points_in_pillar
=
4
,
dim
=
'3d'
,
bs
=
1
,
device
=
'cuda'
,
dtype
=
torch
.
float
):
"""Get the reference points used in SCA and TSA.
Args:
H, W: spatial shape of bev.
Z: hight of pillar.
D: sample D points uniformly from each pillar.
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has
\
shape (bs, num_keys, num_levels, 2).
"""
# reference points in 3D space, used in spatial cross-attention (SCA)
if
dim
==
'3d'
:
zs
=
torch
.
linspace
(
0.5
,
Z
-
0.5
,
num_points_in_pillar
,
dtype
=
dtype
,
device
=
device
).
view
(
-
1
,
1
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
Z
xs
=
torch
.
linspace
(
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
device
=
device
).
view
(
1
,
1
,
W
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
W
ys
=
torch
.
linspace
(
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
device
=
device
).
view
(
1
,
H
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
H
ref_3d
=
torch
.
stack
((
xs
,
ys
,
zs
),
-
1
)
ref_3d
=
ref_3d
.
permute
(
0
,
3
,
1
,
2
).
flatten
(
2
).
permute
(
0
,
2
,
1
)
ref_3d
=
ref_3d
[
None
].
repeat
(
bs
,
1
,
1
,
1
)
return
ref_3d
# reference points on 2D bev plane, used in temporal self-attention (TSA).
elif
dim
==
'2d'
:
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
device
=
device
),
torch
.
linspace
(
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
device
=
device
)
)
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
H
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
W
ref_2d
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
ref_2d
=
ref_2d
.
repeat
(
bs
,
1
,
1
).
unsqueeze
(
2
)
return
ref_2d
# This function must use fp32!!!
@
force_fp32
(
apply_to
=
(
'reference_points'
,
'img_metas'
))
def
point_sampling
(
self
,
reference_points
,
pc_range
,
img_metas
):
# NOTE: close tf32 here.
allow_tf32
=
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cudnn
.
allow_tf32
=
False
lidar2img
=
[]
for
img_meta
in
img_metas
:
lidar2img
.
append
(
img_meta
[
'lidar2img'
])
lidar2img
=
np
.
asarray
(
lidar2img
)
lidar2img
=
reference_points
.
new_tensor
(
lidar2img
)
# (B, N, 4, 4)
reference_points
=
reference_points
.
clone
()
reference_points
[...,
0
:
1
]
=
reference_points
[...,
0
:
1
]
*
\
(
pc_range
[
3
]
-
pc_range
[
0
])
+
pc_range
[
0
]
reference_points
[...,
1
:
2
]
=
reference_points
[...,
1
:
2
]
*
\
(
pc_range
[
4
]
-
pc_range
[
1
])
+
pc_range
[
1
]
reference_points
[...,
2
:
3
]
=
reference_points
[...,
2
:
3
]
*
\
(
pc_range
[
5
]
-
pc_range
[
2
])
+
pc_range
[
2
]
reference_points
=
torch
.
cat
(
(
reference_points
,
torch
.
ones_like
(
reference_points
[...,
:
1
])),
-
1
)
reference_points
=
reference_points
.
permute
(
1
,
0
,
2
,
3
)
D
,
B
,
num_query
=
reference_points
.
size
()[:
3
]
num_cam
=
lidar2img
.
size
(
1
)
reference_points
=
reference_points
.
view
(
D
,
B
,
1
,
num_query
,
4
).
repeat
(
1
,
1
,
num_cam
,
1
,
1
).
unsqueeze
(
-
1
)
lidar2img
=
lidar2img
.
view
(
1
,
B
,
num_cam
,
1
,
4
,
4
).
repeat
(
D
,
1
,
1
,
num_query
,
1
,
1
)
reference_points_cam
=
torch
.
matmul
(
lidar2img
.
to
(
torch
.
float32
),
reference_points
.
to
(
torch
.
float32
)).
squeeze
(
-
1
)
eps
=
1e-5
bev_mask
=
(
reference_points_cam
[...,
2
:
3
]
>
eps
)
reference_points_cam
=
reference_points_cam
[...,
0
:
2
]
/
torch
.
maximum
(
reference_points_cam
[...,
2
:
3
],
torch
.
ones_like
(
reference_points_cam
[...,
2
:
3
])
*
eps
)
reference_points_cam
[...,
0
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
1
]
reference_points_cam
[...,
1
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
0
]
bev_mask
=
(
bev_mask
&
(
reference_points_cam
[...,
1
:
2
]
>
0.0
)
&
(
reference_points_cam
[...,
1
:
2
]
<
1.0
)
&
(
reference_points_cam
[...,
0
:
1
]
<
1.0
)
&
(
reference_points_cam
[...,
0
:
1
]
>
0.0
))
if
digit_version
(
TORCH_VERSION
)
>=
digit_version
(
'1.8'
):
bev_mask
=
torch
.
nan_to_num
(
bev_mask
)
else
:
bev_mask
=
bev_mask
.
new_tensor
(
np
.
nan_to_num
(
bev_mask
.
cpu
().
numpy
()))
reference_points_cam
=
reference_points_cam
.
permute
(
2
,
1
,
3
,
0
,
4
)
bev_mask
=
bev_mask
.
permute
(
2
,
1
,
3
,
0
,
4
).
squeeze
(
-
1
)
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
allow_tf32
torch
.
backends
.
cudnn
.
allow_tf32
=
allow_tf32
return
reference_points_cam
,
bev_mask
@
auto_fp16
()
def
forward
(
self
,
bev_query
,
key
,
value
,
*
args
,
bev_h
=
None
,
bev_w
=
None
,
bev_pos
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
valid_ratios
=
None
,
prev_bev
=
None
,
shift
=
0.
,
**
kwargs
):
"""Forward function for `TransformerDecoder`.
Args:
bev_query (Tensor): Input BEV query with shape
`(num_query, bs, embed_dims)`.
key & value (Tensor): Input multi-cameta features with shape
(num_cam, num_value, bs, embed_dims)
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
valid_ratios (Tensor): The radios of valid
points on the feature map, has shape
(bs, num_levels, 2)
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output
=
bev_query
intermediate
=
[]
ref_3d
=
self
.
get_reference_points
(
bev_h
,
bev_w
,
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
],
self
.
num_points_in_pillar
,
dim
=
'3d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
ref_2d
=
self
.
get_reference_points
(
bev_h
,
bev_w
,
dim
=
'2d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
reference_points_cam
,
bev_mask
=
self
.
point_sampling
(
ref_3d
,
self
.
pc_range
,
kwargs
[
'img_metas'
])
# bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
shift_ref_2d
=
ref_2d
.
clone
()
shift_ref_2d
+=
shift
[:,
None
,
None
,
:]
# (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
bev_query
=
bev_query
.
permute
(
1
,
0
,
2
)
bev_pos
=
bev_pos
.
permute
(
1
,
0
,
2
)
bs
,
len_bev
,
num_bev_level
,
_
=
ref_2d
.
shape
if
prev_bev
is
not
None
:
prev_bev
=
prev_bev
.
permute
(
1
,
0
,
2
)
prev_bev
=
torch
.
stack
(
[
prev_bev
,
bev_query
],
1
).
reshape
(
bs
*
2
,
len_bev
,
-
1
)
hybird_ref_2d
=
torch
.
stack
([
shift_ref_2d
,
ref_2d
],
1
).
reshape
(
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
else
:
hybird_ref_2d
=
torch
.
stack
([
ref_2d
,
ref_2d
],
1
).
reshape
(
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
for
lid
,
layer
in
enumerate
(
self
.
layers
):
output
=
layer
(
bev_query
,
key
,
value
,
*
args
,
bev_pos
=
bev_pos
,
ref_2d
=
hybird_ref_2d
,
ref_3d
=
ref_3d
,
bev_h
=
bev_h
,
bev_w
=
bev_w
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
reference_points_cam
=
reference_points_cam
,
bev_mask
=
bev_mask
,
prev_bev
=
prev_bev
,
**
kwargs
)
bev_query
=
output
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
if
self
.
return_intermediate
:
return
torch
.
stack
(
intermediate
)
return
output
@
TRANSFORMER_LAYER
.
register_module
()
class
BEVFormerLayer
(
MyCustomBaseTransformerLayer
):
"""Implements decoder layer in DETR transformer.
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
Configs for self_attention or cross_attention, the order
should be consistent with it in `operation_order`. If it is
a dict, it would be expand to the number of attention in
`operation_order`.
feedforward_channels (int): The hidden dimension for FFNs.
ffn_dropout (float): Probability of an element to be zeroed
in ffn. Default 0.0.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Default:None
act_cfg (dict): The activation config for FFNs. Default: `LN`
norm_cfg (dict): Config dict for normalization layer.
Default: `LN`.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default:2.
"""
def
__init__
(
self
,
attn_cfgs
,
feedforward_channels
,
ffn_dropout
=
0.0
,
operation_order
=
None
,
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
norm_cfg
=
dict
(
type
=
'LN'
),
ffn_num_fcs
=
2
,
**
kwargs
):
super
(
BEVFormerLayer
,
self
).
__init__
(
attn_cfgs
=
attn_cfgs
,
feedforward_channels
=
feedforward_channels
,
ffn_dropout
=
ffn_dropout
,
operation_order
=
operation_order
,
act_cfg
=
act_cfg
,
norm_cfg
=
norm_cfg
,
ffn_num_fcs
=
ffn_num_fcs
,
**
kwargs
)
self
.
fp16_enabled
=
False
assert
len
(
operation_order
)
==
6
assert
set
(
operation_order
)
==
set
(
[
'self_attn'
,
'norm'
,
'cross_attn'
,
'ffn'
])
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
bev_pos
=
None
,
query_pos
=
None
,
key_pos
=
None
,
attn_masks
=
None
,
query_key_padding_mask
=
None
,
key_padding_mask
=
None
,
ref_2d
=
None
,
ref_3d
=
None
,
bev_h
=
None
,
bev_w
=
None
,
reference_points_cam
=
None
,
mask
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
prev_bev
=
None
,
**
kwargs
):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index
=
0
attn_index
=
0
ffn_index
=
0
identity
=
query
if
attn_masks
is
None
:
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
attn_masks
=
[
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
]
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
f
'
{
self
.
__class__
.
__name__
}
'
)
else
:
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'to the number of attention in '
\
f
'operation_order
{
self
.
num_attn
}
'
for
layer
in
self
.
operation_order
:
# temporal self attention
if
layer
==
'self_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
,
prev_bev
,
prev_bev
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
bev_pos
,
key_pos
=
bev_pos
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
query_key_padding_mask
,
reference_points
=
ref_2d
,
spatial_shapes
=
torch
.
tensor
(
[[
bev_h
,
bev_w
]],
device
=
query
.
device
),
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
**
kwargs
)
attn_index
+=
1
identity
=
query
elif
layer
==
'norm'
:
query
=
self
.
norms
[
norm_index
](
query
)
norm_index
+=
1
# spaital cross attention
elif
layer
==
'cross_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
,
key
,
value
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
key_pos
=
key_pos
,
reference_points
=
ref_3d
,
reference_points_cam
=
reference_points_cam
,
mask
=
mask
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
key_padding_mask
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
**
kwargs
)
attn_index
+=
1
identity
=
query
elif
layer
==
'ffn'
:
query
=
self
.
ffns
[
ffn_index
](
query
,
identity
if
self
.
pre_norm
else
None
)
ffn_index
+=
1
return
query
from
mmcv.cnn.bricks.transformer
import
build_feedforward_network
,
build_attention
@
TRANSFORMER_LAYER
.
register_module
()
class
MM_BEVFormerLayer
(
MyCustomBaseTransformerLayer
):
"""multi-modality fusion layer.
"""
def
__init__
(
self
,
attn_cfgs
,
feedforward_channels
,
ffn_dropout
=
0.0
,
operation_order
=
None
,
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
norm_cfg
=
dict
(
type
=
'LN'
),
ffn_num_fcs
=
2
,
lidar_cross_attn_layer
=
None
,
**
kwargs
):
super
(
MM_BEVFormerLayer
,
self
).
__init__
(
attn_cfgs
=
attn_cfgs
,
feedforward_channels
=
feedforward_channels
,
ffn_dropout
=
ffn_dropout
,
operation_order
=
operation_order
,
act_cfg
=
act_cfg
,
norm_cfg
=
norm_cfg
,
ffn_num_fcs
=
ffn_num_fcs
,
**
kwargs
)
self
.
fp16_enabled
=
False
assert
len
(
operation_order
)
==
6
assert
set
(
operation_order
)
==
set
(
[
'self_attn'
,
'norm'
,
'cross_attn'
,
'ffn'
])
self
.
cross_model_weights
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
0.5
),
requires_grad
=
True
)
if
lidar_cross_attn_layer
:
self
.
lidar_cross_attn_layer
=
build_attention
(
lidar_cross_attn_layer
)
# self.cross_model_weights+=1
else
:
self
.
lidar_cross_attn_layer
=
None
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
bev_pos
=
None
,
query_pos
=
None
,
key_pos
=
None
,
attn_masks
=
None
,
query_key_padding_mask
=
None
,
key_padding_mask
=
None
,
ref_2d
=
None
,
ref_3d
=
None
,
bev_h
=
None
,
bev_w
=
None
,
reference_points_cam
=
None
,
mask
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
prev_bev
=
None
,
debug
=
False
,
depth
=
None
,
depth_z
=
None
,
lidar_bev
=
None
,
radar_bev
=
None
,
**
kwargs
):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index
=
0
attn_index
=
0
ffn_index
=
0
identity
=
query
if
attn_masks
is
None
:
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
attn_masks
=
[
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
]
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
f
'
{
self
.
__class__
.
__name__
}
'
)
else
:
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'to the number of attention in '
\
f
'operation_order
{
self
.
num_attn
}
'
for
layer
in
self
.
operation_order
:
# temporal self attention
if
layer
==
'self_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
,
prev_bev
,
prev_bev
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
bev_pos
,
key_pos
=
bev_pos
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
query_key_padding_mask
,
lidar_bev
=
lidar_bev
,
reference_points
=
ref_2d
,
spatial_shapes
=
torch
.
tensor
(
[[
bev_h
,
bev_w
]],
device
=
query
.
device
),
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
**
kwargs
)
attn_index
+=
1
identity
=
query
elif
layer
==
'norm'
:
query
=
self
.
norms
[
norm_index
](
query
)
norm_index
+=
1
# spaital cross attention
elif
layer
==
'cross_attn'
:
new_query1
=
self
.
attentions
[
attn_index
](
query
,
key
,
value
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
key_pos
=
key_pos
,
reference_points
=
ref_3d
,
reference_points_cam
=
reference_points_cam
,
mask
=
mask
,
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
key_padding_mask
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
depth
=
depth
,
lidar_bev
=
lidar_bev
,
depth_z
=
depth_z
,
**
kwargs
)
if
self
.
lidar_cross_attn_layer
:
bs
=
query
.
size
(
0
)
new_query2
=
self
.
lidar_cross_attn_layer
(
query
,
lidar_bev
,
lidar_bev
,
reference_points
=
ref_2d
[
bs
:],
spatial_shapes
=
torch
.
tensor
(
[[
bev_h
,
bev_w
]],
device
=
query
.
device
),
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
)
query
=
new_query1
*
self
.
cross_model_weights
+
(
1
-
self
.
cross_model_weights
)
*
new_query2
attn_index
+=
1
identity
=
query
elif
layer
==
'ffn'
:
query
=
self
.
ffns
[
ffn_index
](
query
,
identity
if
self
.
pre_norm
else
None
)
ffn_index
+=
1
return
query
projects/mmdet3d_plugin/bevformer/modules/group_attention.py
0 → 100644
View file @
4cd43886
import
copy
import
math
import
warnings
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
(
Linear
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
)
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.utils
import
(
ConfigDict
,
build_from_cfg
,
deprecated_api_warning
,
to_2tuple
)
from
mmcv.cnn.bricks.drop
import
build_dropout
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
FEEDFORWARD_NETWORK
,
POSITIONAL_ENCODING
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
@
ATTENTION
.
register_module
()
class
GroupMultiheadAttention
(
BaseModule
):
"""A wrapper for ``torch.nn.MultiheadAttention``.
This module implements MultiheadAttention with identity connection,
and positional encoding is also passed as input.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): When it is True, Key, Query and Value are shape of
(batch, n, embed_dim), otherwise (n, batch, embed_dim).
Default to False.
"""
def
__init__
(
self
,
embed_dims
,
num_heads
,
attn_drop
=
0.
,
proj_drop
=
0.
,
group
=
1
,
dropout_layer
=
dict
(
type
=
'Dropout'
,
drop_prob
=
0.
),
init_cfg
=
None
,
batch_first
=
False
,
**
kwargs
):
super
().
__init__
(
init_cfg
)
if
'dropout'
in
kwargs
:
warnings
.
warn
(
'The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) '
,
DeprecationWarning
)
attn_drop
=
kwargs
[
'dropout'
]
dropout_layer
[
'drop_prob'
]
=
kwargs
.
pop
(
'dropout'
)
self
.
embed_dims
=
embed_dims
self
.
num_heads
=
num_heads
self
.
group
=
group
self
.
batch_first
=
batch_first
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dims
,
num_heads
,
attn_drop
,
**
kwargs
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
dropout_layer
=
build_dropout
(
dropout_layer
)
if
dropout_layer
else
nn
.
Identity
()
@
deprecated_api_warning
({
'residual'
:
'identity'
},
cls_name
=
'MultiheadAttention'
)
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
identity
=
None
,
query_pos
=
None
,
key_pos
=
None
,
attn_mask
=
None
,
key_padding_mask
=
None
,
**
kwargs
):
"""Forward function for `MultiheadAttention`.
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims] if self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
identity (Tensor): This tensor, with the same shape as x,
will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
be added to `x` before forward function. Defaults to None.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Defaults to None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`. Defaults to None.
attn_mask (Tensor): ByteTensor mask with shape [num_queries,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Defaults to None.
Returns:
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
[bs, num_queries embed_dims].
"""
if
key
is
None
:
key
=
query
if
value
is
None
:
value
=
key
if
identity
is
None
:
identity
=
query
if
key_pos
is
None
:
if
query_pos
is
not
None
:
# use query_pos if key_pos is not available
if
query_pos
.
shape
==
key
.
shape
:
key_pos
=
query_pos
else
:
warnings
.
warn
(
f
'position encoding of key is'
f
'missing in
{
self
.
__class__
.
__name__
}
.'
)
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
key_pos
is
not
None
:
key
=
key
+
key_pos
# Because the dataflow('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), We should adjust the shape of dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query ,batch, embed_dims), and recover ``attn_output``
# from num_query_first to batch_first.
if
self
.
batch_first
:
query
=
query
.
transpose
(
0
,
1
)
key
=
key
.
transpose
(
0
,
1
)
value
=
value
.
transpose
(
0
,
1
)
num_queries
=
query
.
shape
[
0
]
bs
=
query
.
shape
[
1
]
if
self
.
training
:
query
=
torch
.
cat
(
query
.
split
(
num_queries
//
self
.
group
,
dim
=
0
),
dim
=
1
)
key
=
torch
.
cat
(
key
.
split
(
num_queries
//
self
.
group
,
dim
=
0
),
dim
=
1
)
value
=
torch
.
cat
(
value
.
split
(
num_queries
//
self
.
group
,
dim
=
0
),
dim
=
1
)
out
=
self
.
attn
(
query
=
query
,
key
=
key
,
value
=
value
,
attn_mask
=
attn_mask
,
key_padding_mask
=
key_padding_mask
)[
0
]
if
self
.
training
:
out
=
torch
.
cat
(
out
.
split
(
bs
,
dim
=
1
),
dim
=
0
)
# shape
if
self
.
batch_first
:
out
=
out
.
transpose
(
0
,
1
)
return
identity
+
self
.
dropout_layer
(
self
.
proj_drop
(
out
))
projects/mmdet3d_plugin/bevformer/modules/multi_scale_deformable_attn_function.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
torch
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
torch.autograd.function
import
Function
,
once_differentiable
from
mmcv.utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
class
MultiScaleDeformableAttnFunction_fp16
(
Function
):
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx
.
im2col_step
=
im2col_step
output
=
ext_module
.
ms_deform_attn_forward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
=
ctx
.
im2col_step
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
)
return
output
@
staticmethod
@
once_differentiable
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
value
,
value_spatial_shapes
,
value_level_start_index
,
\
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
grad_value
=
torch
.
zeros_like
(
value
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
ext_module
.
ms_deform_attn_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
grad_output
.
contiguous
(),
grad_value
,
grad_sampling_loc
,
grad_attn_weight
,
im2col_step
=
ctx
.
im2col_step
)
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
class
MultiScaleDeformableAttnFunction_fp32
(
Function
):
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx
.
im2col_step
=
im2col_step
output
=
ext_module
.
ms_deform_attn_forward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
=
ctx
.
im2col_step
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
)
return
output
@
staticmethod
@
once_differentiable
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
value
,
value_spatial_shapes
,
value_level_start_index
,
\
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
grad_value
=
torch
.
zeros_like
(
value
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
ext_module
.
ms_deform_attn_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
grad_output
.
contiguous
(),
grad_value
,
grad_sampling_loc
,
grad_attn_weight
,
im2col_step
=
ctx
.
im2col_step
)
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
mmcv.ops.multi_scale_deform_attn
import
multi_scale_deformable_attn_pytorch
import
warnings
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
build_attention
import
math
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.utils
import
ext_loader
from
.multi_scale_deformable_attn_function
import
MultiScaleDeformableAttnFunction_fp32
,
\
MultiScaleDeformableAttnFunction_fp16
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
ATTENTION
.
register_module
()
class
SpatialCrossAttention
(
BaseModule
):
"""An attention module used in BEVFormer.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_cams (int): The number of cameras
dropout (float): A Dropout layer on `inp_residual`.
Default: 0..
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
deformable_attention: (dict): The config for the deformable attention used in SCA.
"""
def
__init__
(
self
,
embed_dims
=
256
,
num_cams
=
6
,
pc_range
=
None
,
dropout
=
0.1
,
init_cfg
=
None
,
batch_first
=
False
,
deformable_attention
=
dict
(
type
=
'MSDeformableAttention3D'
,
embed_dims
=
256
,
num_levels
=
4
),
**
kwargs
):
super
(
SpatialCrossAttention
,
self
).
__init__
(
init_cfg
)
self
.
init_cfg
=
init_cfg
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
self
.
deformable_attention
=
build_attention
(
deformable_attention
)
self
.
embed_dims
=
embed_dims
self
.
num_cams
=
num_cams
self
.
output_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
batch_first
=
batch_first
self
.
init_weight
()
def
init_weight
(
self
):
"""Default initialization for Parameters of Module."""
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
@
force_fp32
(
apply_to
=
(
'query'
,
'key'
,
'value'
,
'query_pos'
,
'reference_points_cam'
))
def
forward
(
self
,
query
,
key
,
value
,
residual
=
None
,
query_pos
=
None
,
key_padding_mask
=
None
,
reference_points
=
None
,
spatial_shapes
=
None
,
reference_points_cam
=
None
,
bev_mask
=
None
,
level_start_index
=
None
,
flag
=
'encoder'
,
**
kwargs
):
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`. (B, N, C, H, W)
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, 4),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if
key
is
None
:
key
=
query
if
value
is
None
:
value
=
key
if
residual
is
None
:
inp_residual
=
query
slots
=
torch
.
zeros_like
(
query
)
if
query_pos
is
not
None
:
query
=
query
+
query_pos
bs
,
num_query
,
_
=
query
.
size
()
D
=
reference_points_cam
.
size
(
3
)
indexes
=
[]
for
i
,
mask_per_img
in
enumerate
(
bev_mask
):
index_query_per_img
=
mask_per_img
[
0
].
sum
(
-
1
).
nonzero
().
squeeze
(
-
1
)
indexes
.
append
(
index_query_per_img
)
max_len
=
max
([
len
(
each
)
for
each
in
indexes
])
# each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory.
queries_rebatch
=
query
.
new_zeros
(
[
bs
,
self
.
num_cams
,
max_len
,
self
.
embed_dims
])
reference_points_rebatch
=
reference_points_cam
.
new_zeros
(
[
bs
,
self
.
num_cams
,
max_len
,
D
,
2
])
for
j
in
range
(
bs
):
for
i
,
reference_points_per_img
in
enumerate
(
reference_points_cam
):
index_query_per_img
=
indexes
[
i
]
queries_rebatch
[
j
,
i
,
:
len
(
index_query_per_img
)]
=
query
[
j
,
index_query_per_img
]
reference_points_rebatch
[
j
,
i
,
:
len
(
index_query_per_img
)]
=
reference_points_per_img
[
j
,
index_query_per_img
]
num_cams
,
l
,
bs
,
embed_dims
=
key
.
shape
key
=
key
.
permute
(
2
,
0
,
1
,
3
).
reshape
(
bs
*
self
.
num_cams
,
l
,
self
.
embed_dims
)
value
=
value
.
permute
(
2
,
0
,
1
,
3
).
reshape
(
bs
*
self
.
num_cams
,
l
,
self
.
embed_dims
)
queries
=
self
.
deformable_attention
(
query
=
queries_rebatch
.
view
(
bs
*
self
.
num_cams
,
max_len
,
self
.
embed_dims
),
key
=
key
,
value
=
value
,
reference_points
=
reference_points_rebatch
.
view
(
bs
*
self
.
num_cams
,
max_len
,
D
,
2
),
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
).
view
(
bs
,
self
.
num_cams
,
max_len
,
self
.
embed_dims
)
for
j
in
range
(
bs
):
for
i
,
index_query_per_img
in
enumerate
(
indexes
):
slots
[
j
,
index_query_per_img
]
+=
queries
[
j
,
i
,
:
len
(
index_query_per_img
)]
count
=
bev_mask
.
sum
(
-
1
)
>
0
count
=
count
.
permute
(
1
,
2
,
0
).
sum
(
-
1
)
count
=
torch
.
clamp
(
count
,
min
=
1.0
)
slots
=
slots
/
count
[...,
None
]
slots
=
self
.
output_proj
(
slots
)
return
self
.
dropout
(
slots
)
+
inp_residual
@
ATTENTION
.
register_module
()
class
MSDeformableAttention3D
(
BaseModule
):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def
__init__
(
self
,
embed_dims
=
256
,
num_heads
=
8
,
num_levels
=
4
,
num_points
=
8
,
im2col_step
=
64
,
dropout
=
0.1
,
batch_first
=
True
,
norm_cfg
=
None
,
init_cfg
=
None
):
super
().
__init__
(
init_cfg
)
if
embed_dims
%
num_heads
!=
0
:
raise
ValueError
(
f
'embed_dims must be divisible by num_heads, '
f
'but got
{
embed_dims
}
and
{
num_heads
}
'
)
dim_per_head
=
embed_dims
//
num_heads
self
.
norm_cfg
=
norm_cfg
self
.
batch_first
=
batch_first
self
.
output_proj
=
None
self
.
fp16_enabled
=
False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
'invalid input for _is_power_of_2: {} (type: {})'
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
if
not
_is_power_of_2
(
dim_per_head
):
warnings
.
warn
(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.'
)
self
.
im2col_step
=
im2col_step
self
.
embed_dims
=
embed_dims
self
.
num_levels
=
num_levels
self
.
num_heads
=
num_heads
self
.
num_points
=
num_points
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dims
,
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
init_weights
()
def
init_weights
(
self
):
"""Default initialization for Parameters of Module."""
constant_init
(
self
.
sampling_offsets
,
0.
)
thetas
=
torch
.
arange
(
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
]).
view
(
self
.
num_heads
,
1
,
1
,
2
).
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
self
.
sampling_offsets
.
bias
.
data
=
grid_init
.
view
(
-
1
)
constant_init
(
self
.
attention_weights
,
val
=
0.
,
bias
=
0.
)
xavier_init
(
self
.
value_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
self
.
_is_init
=
True
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
identity
=
None
,
query_pos
=
None
,
key_padding_mask
=
None
,
reference_points
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
**
kwargs
):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
( bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
`(bs, num_key, embed_dims)`.
value (Tensor): The value tensor with shape
`(bs, num_key, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if
value
is
None
:
value
=
query
if
identity
is
None
:
identity
=
query
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
0.0
)
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
)
if
reference_points
.
shape
[
-
1
]
==
2
:
"""
For each BEV query, it owns `num_Z_anchors` in 3D space that having different heights.
After proejcting, each BEV query has `num_Z_anchors` reference points in each 2D image.
For each referent point, we sample `num_points` sampling points.
For `num_Z_anchors` reference points, it has overall `num_points * num_Z_anchors` sampling points.
"""
offset_normalizer
=
torch
.
stack
(
[
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
bs
,
num_query
,
num_Z_anchors
,
xy
=
reference_points
.
shape
reference_points
=
reference_points
[:,
:,
None
,
None
,
None
,
:,
:]
sampling_offsets
=
sampling_offsets
/
\
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
bs
,
num_query
,
num_heads
,
num_levels
,
num_all_points
,
xy
=
sampling_offsets
.
shape
sampling_offsets
=
sampling_offsets
.
view
(
bs
,
num_query
,
num_heads
,
num_levels
,
num_all_points
//
num_Z_anchors
,
num_Z_anchors
,
xy
)
sampling_locations
=
reference_points
+
sampling_offsets
bs
,
num_query
,
num_heads
,
num_levels
,
num_points
,
num_Z_anchors
,
xy
=
sampling_locations
.
shape
assert
num_all_points
==
num_points
*
num_Z_anchors
sampling_locations
=
sampling_locations
.
view
(
bs
,
num_query
,
num_heads
,
num_levels
,
num_all_points
,
xy
)
elif
reference_points
.
shape
[
-
1
]
==
4
:
assert
False
else
:
raise
ValueError
(
f
'Last dim of reference_points must be'
f
' 2 or 4, but get
{
reference_points
.
shape
[
-
1
]
}
instead.'
)
# sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2
# attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points
#
if
torch
.
cuda
.
is_available
()
and
value
.
is_cuda
:
if
value
.
dtype
==
torch
.
float16
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
else
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
)
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
if
not
self
.
batch_first
:
output
=
output
.
permute
(
1
,
0
,
2
)
return
output
projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
from
.multi_scale_deformable_attn_function
import
MultiScaleDeformableAttnFunction_fp32
from
mmcv.ops.multi_scale_deform_attn
import
multi_scale_deformable_attn_pytorch
import
warnings
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
ATTENTION
import
math
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
from
mmcv.utils
import
(
ConfigDict
,
build_from_cfg
,
deprecated_api_warning
,
to_2tuple
)
from
mmcv.utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
ATTENTION
.
register_module
()
class
TemporalSelfAttention
(
BaseModule
):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to True.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
num_bev_queue (int): In this version, we only use one history BEV and one currenct BEV.
the length of BEV queue is 2.
"""
def
__init__
(
self
,
embed_dims
=
256
,
num_heads
=
8
,
num_levels
=
4
,
num_points
=
4
,
num_bev_queue
=
2
,
im2col_step
=
64
,
dropout
=
0.1
,
batch_first
=
True
,
norm_cfg
=
None
,
init_cfg
=
None
):
super
().
__init__
(
init_cfg
)
if
embed_dims
%
num_heads
!=
0
:
raise
ValueError
(
f
'embed_dims must be divisible by num_heads, '
f
'but got
{
embed_dims
}
and
{
num_heads
}
'
)
dim_per_head
=
embed_dims
//
num_heads
self
.
norm_cfg
=
norm_cfg
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
batch_first
=
batch_first
self
.
fp16_enabled
=
False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
'invalid input for _is_power_of_2: {} (type: {})'
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
if
not
_is_power_of_2
(
dim_per_head
):
warnings
.
warn
(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.'
)
self
.
im2col_step
=
im2col_step
self
.
embed_dims
=
embed_dims
self
.
num_levels
=
num_levels
self
.
num_heads
=
num_heads
self
.
num_points
=
num_points
self
.
num_bev_queue
=
num_bev_queue
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dims
*
self
.
num_bev_queue
,
num_bev_queue
*
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dims
*
self
.
num_bev_queue
,
num_bev_queue
*
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
output_proj
=
nn
.
Linear
(
embed_dims
,
embed_dims
)
self
.
init_weights
()
def
init_weights
(
self
):
"""Default initialization for Parameters of Module."""
constant_init
(
self
.
sampling_offsets
,
0.
)
thetas
=
torch
.
arange
(
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
]).
view
(
self
.
num_heads
,
1
,
1
,
2
).
repeat
(
1
,
self
.
num_levels
*
self
.
num_bev_queue
,
self
.
num_points
,
1
)
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
self
.
sampling_offsets
.
bias
.
data
=
grid_init
.
view
(
-
1
)
constant_init
(
self
.
attention_weights
,
val
=
0.
,
bias
=
0.
)
xavier_init
(
self
.
value_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
output_proj
,
distribution
=
'uniform'
,
bias
=
0.
)
self
.
_is_init
=
True
def
forward
(
self
,
query
,
key
=
None
,
value
=
None
,
identity
=
None
,
query_pos
=
None
,
key_padding_mask
=
None
,
reference_points
=
None
,
spatial_shapes
=
None
,
level_start_index
=
None
,
flag
=
'decoder'
,
**
kwargs
):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if
value
is
None
:
assert
self
.
batch_first
bs
,
len_bev
,
c
=
query
.
shape
value
=
torch
.
stack
([
query
,
query
],
1
).
reshape
(
bs
*
2
,
len_bev
,
c
)
# value = torch.cat([query, query], 0)
if
identity
is
None
:
identity
=
query
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
embed_dims
=
query
.
shape
_
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
assert
self
.
num_bev_queue
==
2
query
=
torch
.
cat
([
value
[:
bs
],
query
],
-
1
)
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
0.0
)
value
=
value
.
reshape
(
bs
*
self
.
num_bev_queue
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
)
sampling_offsets
=
sampling_offsets
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_bev_queue
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_bev_queue
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_bev_queue
,
self
.
num_levels
,
self
.
num_points
)
attention_weights
=
attention_weights
.
permute
(
0
,
3
,
1
,
2
,
4
,
5
)
\
.
reshape
(
bs
*
self
.
num_bev_queue
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
).
contiguous
()
sampling_offsets
=
sampling_offsets
.
permute
(
0
,
3
,
1
,
2
,
4
,
5
,
6
)
\
.
reshape
(
bs
*
self
.
num_bev_queue
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
if
reference_points
.
shape
[
-
1
]
==
2
:
offset_normalizer
=
torch
.
stack
(
[
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:]
\
+
sampling_offsets
\
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
elif
reference_points
.
shape
[
-
1
]
==
4
:
sampling_locations
=
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
\
+
sampling_offsets
/
self
.
num_points
\
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
\
*
0.5
else
:
raise
ValueError
(
f
'Last dim of reference_points must be'
f
' 2 or 4, but get
{
reference_points
.
shape
[
-
1
]
}
instead.'
)
if
torch
.
cuda
.
is_available
()
and
value
.
is_cuda
:
# using fp16 deformable attention is unstable because it performs many sum operations
if
value
.
dtype
==
torch
.
float16
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
else
:
MultiScaleDeformableAttnFunction
=
MultiScaleDeformableAttnFunction_fp32
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
)
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
# output shape (bs*num_bev_queue, num_query, embed_dims)
# (bs*num_bev_queue, num_query, embed_dims)-> (num_query, embed_dims, bs*num_bev_queue)
output
=
output
.
permute
(
1
,
2
,
0
)
# fuse history value and current value
# (num_query, embed_dims, bs*num_bev_queue)-> (num_query, embed_dims, bs, num_bev_queue)
output
=
output
.
view
(
num_query
,
embed_dims
,
bs
,
self
.
num_bev_queue
)
output
=
output
.
mean
(
-
1
)
# (num_query, embed_dims, bs)-> (bs, num_query, embed_dims)
output
=
output
.
permute
(
2
,
0
,
1
)
output
=
self
.
output_proj
(
output
)
if
not
self
.
batch_first
:
output
=
output
.
permute
(
1
,
0
,
2
)
return
self
.
dropout
(
output
)
+
identity
projects/mmdet3d_plugin/bevformer/modules/transformer.py
0 → 100644
View file @
4cd43886
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.transformer
import
build_transformer_layer_sequence
from
mmcv.runner.base_module
import
BaseModule
from
mmdet.models.utils.builder
import
TRANSFORMER
from
torch.nn.init
import
normal_
from
projects.mmdet3d_plugin.models.utils.visual
import
save_tensor
from
mmcv.runner.base_module
import
BaseModule
from
torchvision.transforms.functional
import
rotate
from
.temporal_self_attention
import
TemporalSelfAttention
from
.spatial_cross_attention
import
MSDeformableAttention3D
from
.decoder
import
CustomMSDeformableAttention
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
from
mmcv.runner
import
force_fp32
,
auto_fp16
@
TRANSFORMER
.
register_module
()
class
PerceptionTransformer
(
BaseModule
):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def
__init__
(
self
,
num_feature_levels
=
4
,
num_cams
=
6
,
two_stage_num_proposals
=
300
,
encoder
=
None
,
decoder
=
None
,
embed_dims
=
256
,
rotate_prev_bev
=
True
,
use_shift
=
True
,
use_can_bus
=
True
,
can_bus_norm
=
True
,
use_cams_embeds
=
True
,
rotate_center
=
[
100
,
100
],
**
kwargs
):
super
(
PerceptionTransformer
,
self
).
__init__
(
**
kwargs
)
self
.
encoder
=
build_transformer_layer_sequence
(
encoder
)
self
.
decoder
=
build_transformer_layer_sequence
(
decoder
)
self
.
embed_dims
=
embed_dims
self
.
num_feature_levels
=
num_feature_levels
self
.
num_cams
=
num_cams
self
.
fp16_enabled
=
False
self
.
rotate_prev_bev
=
rotate_prev_bev
self
.
use_shift
=
use_shift
self
.
use_can_bus
=
use_can_bus
self
.
can_bus_norm
=
can_bus_norm
self
.
use_cams_embeds
=
use_cams_embeds
self
.
two_stage_num_proposals
=
two_stage_num_proposals
self
.
init_layers
()
self
.
rotate_center
=
rotate_center
def
init_layers
(
self
):
"""Initialize layers of the Detr3DTransformer."""
self
.
level_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_feature_levels
,
self
.
embed_dims
))
self
.
cams_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_cams
,
self
.
embed_dims
))
self
.
reference_points
=
nn
.
Linear
(
self
.
embed_dims
,
3
)
self
.
can_bus_mlp
=
nn
.
Sequential
(
nn
.
Linear
(
18
,
self
.
embed_dims
//
2
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
embed_dims
//
2
,
self
.
embed_dims
),
nn
.
ReLU
(
inplace
=
True
),
)
if
self
.
can_bus_norm
:
self
.
can_bus_mlp
.
add_module
(
'norm'
,
nn
.
LayerNorm
(
self
.
embed_dims
))
def
init_weights
(
self
):
"""Initialize the transformer weights."""
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
MSDeformableAttention3D
)
or
isinstance
(
m
,
TemporalSelfAttention
)
\
or
isinstance
(
m
,
CustomMSDeformableAttention
):
try
:
m
.
init_weight
()
except
AttributeError
:
m
.
init_weights
()
normal_
(
self
.
level_embeds
)
normal_
(
self
.
cams_embeds
)
xavier_init
(
self
.
reference_points
,
distribution
=
'uniform'
,
bias
=
0.
)
xavier_init
(
self
.
can_bus_mlp
,
distribution
=
'uniform'
,
bias
=
0.
)
@
auto_fp16
(
apply_to
=
(
'mlvl_feats'
,
'bev_queries'
,
'prev_bev'
,
'bev_pos'
))
def
get_bev_features
(
self
,
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
=
[
0.512
,
0.512
],
bev_pos
=
None
,
prev_bev
=
None
,
**
kwargs
):
"""
obtain bev features.
"""
bs
=
mlvl_feats
[
0
].
size
(
0
)
bev_queries
=
bev_queries
.
unsqueeze
(
1
).
repeat
(
1
,
bs
,
1
)
bev_pos
=
bev_pos
.
flatten
(
2
).
permute
(
2
,
0
,
1
)
# obtain rotation angle and shift with ego motion
delta_x
=
np
.
array
([
each
[
'can_bus'
][
0
]
for
each
in
kwargs
[
'img_metas'
]])
delta_y
=
np
.
array
([
each
[
'can_bus'
][
1
]
for
each
in
kwargs
[
'img_metas'
]])
ego_angle
=
np
.
array
(
[
each
[
'can_bus'
][
-
2
]
/
np
.
pi
*
180
for
each
in
kwargs
[
'img_metas'
]])
grid_length_y
=
grid_length
[
0
]
grid_length_x
=
grid_length
[
1
]
translation_length
=
np
.
sqrt
(
delta_x
**
2
+
delta_y
**
2
)
translation_angle
=
np
.
arctan2
(
delta_y
,
delta_x
)
/
np
.
pi
*
180
bev_angle
=
ego_angle
-
translation_angle
shift_y
=
translation_length
*
\
np
.
cos
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_y
/
bev_h
shift_x
=
translation_length
*
\
np
.
sin
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_x
/
bev_w
shift_y
=
shift_y
*
self
.
use_shift
shift_x
=
shift_x
*
self
.
use_shift
shift
=
bev_queries
.
new_tensor
(
[
shift_x
,
shift_y
]).
permute
(
1
,
0
)
# xy, bs -> bs, xy
if
prev_bev
is
not
None
:
if
prev_bev
.
shape
[
1
]
==
bev_h
*
bev_w
:
prev_bev
=
prev_bev
.
permute
(
1
,
0
,
2
)
if
self
.
rotate_prev_bev
:
for
i
in
range
(
bs
):
# num_prev_bev = prev_bev.size(1)
rotation_angle
=
kwargs
[
'img_metas'
][
i
][
'can_bus'
][
-
1
]
tmp_prev_bev
=
prev_bev
[:,
i
].
reshape
(
bev_h
,
bev_w
,
-
1
).
permute
(
2
,
0
,
1
)
tmp_prev_bev
=
rotate
(
tmp_prev_bev
,
rotation_angle
,
center
=
self
.
rotate_center
)
tmp_prev_bev
=
tmp_prev_bev
.
permute
(
1
,
2
,
0
).
reshape
(
bev_h
*
bev_w
,
1
,
-
1
)
prev_bev
[:,
i
]
=
tmp_prev_bev
[:,
0
]
# add can bus signals
can_bus
=
bev_queries
.
new_tensor
(
[
each
[
'can_bus'
]
for
each
in
kwargs
[
'img_metas'
]])
# [:, :]
can_bus
=
self
.
can_bus_mlp
(
can_bus
)[
None
,
:,
:]
bev_queries
=
bev_queries
+
can_bus
*
self
.
use_can_bus
feat_flatten
=
[]
spatial_shapes
=
[]
for
lvl
,
feat
in
enumerate
(
mlvl_feats
):
bs
,
num_cam
,
c
,
h
,
w
=
feat
.
shape
spatial_shape
=
(
h
,
w
)
feat
=
feat
.
flatten
(
3
).
permute
(
1
,
0
,
3
,
2
)
if
self
.
use_cams_embeds
:
feat
=
feat
+
self
.
cams_embeds
[:,
None
,
None
,
:].
to
(
feat
.
dtype
)
feat
=
feat
+
self
.
level_embeds
[
None
,
None
,
lvl
:
lvl
+
1
,
:].
to
(
feat
.
dtype
)
spatial_shapes
.
append
(
spatial_shape
)
feat_flatten
.
append
(
feat
)
feat_flatten
=
torch
.
cat
(
feat_flatten
,
2
)
spatial_shapes
=
torch
.
as_tensor
(
spatial_shapes
,
dtype
=
torch
.
long
,
device
=
bev_pos
.
device
)
level_start_index
=
torch
.
cat
((
spatial_shapes
.
new_zeros
(
(
1
,)),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]))
feat_flatten
=
feat_flatten
.
permute
(
0
,
2
,
1
,
3
)
# (num_cam, H*W, bs, embed_dims)
bev_embed
=
self
.
encoder
(
bev_queries
,
feat_flatten
,
feat_flatten
,
bev_h
=
bev_h
,
bev_w
=
bev_w
,
bev_pos
=
bev_pos
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
prev_bev
=
prev_bev
,
shift
=
shift
,
**
kwargs
)
return
bev_embed
@
auto_fp16
(
apply_to
=
(
'mlvl_feats'
,
'bev_queries'
,
'object_query_embed'
,
'prev_bev'
,
'bev_pos'
))
def
forward
(
self
,
mlvl_feats
,
bev_queries
,
object_query_embed
,
bev_h
,
bev_w
,
grid_length
=
[
0.512
,
0.512
],
bev_pos
=
None
,
reg_branches
=
None
,
cls_branches
=
None
,
prev_bev
=
None
,
**
kwargs
):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, num_cams, embed_dims, h, w].
bev_queries (Tensor): (bev_h*bev_w, c)
bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
object_query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when `with_box_refine` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- bev_embed: BEV features
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape
\
(num_dec_layers, bs, num_query, embed_dims), else has
\
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference
\
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference
\
points in decoder, has shape
\
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of
\
proposals generated from
\
encoder's feature maps, has shape
\
(batch, h*w, num_classes).
\
Only would be returned when `as_two_stage` is True,
\
otherwise None.
- enc_outputs_coord_unact: The regression results
\
generated from encoder's feature maps., has shape
\
(batch, h*w, 4). Only would
\
be returned when `as_two_stage` is True,
\
otherwise None.
"""
bev_embed
=
self
.
get_bev_features
(
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
=
grid_length
,
bev_pos
=
bev_pos
,
prev_bev
=
prev_bev
,
**
kwargs
)
# bev_embed shape: bs, bev_h*bev_w, embed_dims
bs
=
mlvl_feats
[
0
].
size
(
0
)
query_pos
,
query
=
torch
.
split
(
object_query_embed
,
self
.
embed_dims
,
dim
=
1
)
query_pos
=
query_pos
.
unsqueeze
(
0
).
expand
(
bs
,
-
1
,
-
1
)
query
=
query
.
unsqueeze
(
0
).
expand
(
bs
,
-
1
,
-
1
)
reference_points
=
self
.
reference_points
(
query_pos
)
reference_points
=
reference_points
.
sigmoid
()
init_reference_out
=
reference_points
query
=
query
.
permute
(
1
,
0
,
2
)
query_pos
=
query_pos
.
permute
(
1
,
0
,
2
)
bev_embed
=
bev_embed
.
permute
(
1
,
0
,
2
)
inter_states
,
inter_references
=
self
.
decoder
(
query
=
query
,
key
=
None
,
value
=
bev_embed
,
query_pos
=
query_pos
,
reference_points
=
reference_points
,
reg_branches
=
reg_branches
,
cls_branches
=
cls_branches
,
spatial_shapes
=
torch
.
tensor
([[
bev_h
,
bev_w
]],
device
=
query
.
device
),
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
**
kwargs
)
inter_references_out
=
inter_references
return
bev_embed
,
inter_states
,
init_reference_out
,
inter_references_out
projects/mmdet3d_plugin/bevformer/modules/transformerV2.py
0 → 100644
View file @
4cd43886
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.transformer
import
build_transformer_layer_sequence
from
mmdet.models.utils.builder
import
TRANSFORMER
from
torch.nn.init
import
normal_
from
mmcv.runner.base_module
import
BaseModule
from
.temporal_self_attention
import
TemporalSelfAttention
from
.spatial_cross_attention
import
MSDeformableAttention3D
from
.decoder
import
CustomMSDeformableAttention
from
mmcv.cnn
import
build_norm_layer
,
build_conv_layer
import
torch.utils.checkpoint
as
checkpoint
from
mmdet.models.backbones.resnet
import
Bottleneck
,
BasicBlock
class
ResNetFusion
(
BaseModule
):
def
__init__
(
self
,
in_channels
,
out_channels
,
inter_channels
,
num_layer
,
norm_cfg
=
dict
(
type
=
'SyncBN'
),
with_cp
=
False
):
super
(
ResNetFusion
,
self
).
__init__
()
layers
=
[]
self
.
inter_channels
=
inter_channels
for
i
in
range
(
num_layer
):
if
i
==
0
:
if
inter_channels
==
in_channels
:
layers
.
append
(
BasicBlock
(
in_channels
,
inter_channels
,
stride
=
1
,
norm_cfg
=
norm_cfg
))
else
:
downsample
=
nn
.
Sequential
(
build_conv_layer
(
None
,
in_channels
,
inter_channels
,
3
,
stride
=
1
,
padding
=
1
,
dilation
=
1
,
bias
=
False
),
build_norm_layer
(
norm_cfg
,
inter_channels
)[
1
])
layers
.
append
(
BasicBlock
(
in_channels
,
inter_channels
,
stride
=
1
,
norm_cfg
=
norm_cfg
,
downsample
=
downsample
))
else
:
layers
.
append
(
BasicBlock
(
inter_channels
,
inter_channels
,
stride
=
1
,
norm_cfg
=
norm_cfg
))
self
.
layers
=
nn
.
Sequential
(
*
layers
)
self
.
layer_norm
=
nn
.
Sequential
(
nn
.
Linear
(
inter_channels
,
out_channels
),
nn
.
LayerNorm
(
out_channels
))
self
.
with_cp
=
with_cp
def
forward
(
self
,
x
):
x
=
torch
.
cat
(
x
,
1
).
contiguous
()
# x should be [1, in_channels, bev_h, bev_w]
for
lid
,
layer
in
enumerate
(
self
.
layers
):
if
self
.
with_cp
and
x
.
requires_grad
:
x
=
checkpoint
.
checkpoint
(
layer
,
x
)
else
:
x
=
layer
(
x
)
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
],
-
1
).
permute
(
0
,
2
,
1
)
# nchw -> n(hw)c
x
=
self
.
layer_norm
(
x
)
return
x
@
TRANSFORMER
.
register_module
()
class
PerceptionTransformerBEVEncoder
(
BaseModule
):
def
__init__
(
self
,
num_feature_levels
=
4
,
num_cams
=
6
,
two_stage_num_proposals
=
300
,
encoder
=
None
,
embed_dims
=
256
,
use_cams_embeds
=
True
,
rotate_center
=
[
100
,
100
],
**
kwargs
):
super
(
PerceptionTransformerBEVEncoder
,
self
).
__init__
(
**
kwargs
)
self
.
encoder
=
build_transformer_layer_sequence
(
encoder
)
self
.
embed_dims
=
embed_dims
self
.
num_feature_levels
=
num_feature_levels
self
.
num_cams
=
num_cams
self
.
fp16_enabled
=
False
self
.
use_cams_embeds
=
use_cams_embeds
self
.
two_stage_num_proposals
=
two_stage_num_proposals
self
.
rotate_center
=
rotate_center
"""Initialize layers of the Detr3DTransformer."""
self
.
level_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_feature_levels
,
self
.
embed_dims
))
if
self
.
use_cams_embeds
:
self
.
cams_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_cams
,
self
.
embed_dims
))
def
init_weights
(
self
):
"""Initialize the transformer weights."""
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
MSDeformableAttention3D
)
or
isinstance
(
m
,
TemporalSelfAttention
)
\
or
isinstance
(
m
,
CustomMSDeformableAttention
):
try
:
m
.
init_weight
()
except
AttributeError
:
m
.
init_weights
()
normal_
(
self
.
level_embeds
)
if
self
.
use_cams_embeds
:
normal_
(
self
.
cams_embeds
)
def
forward
(
self
,
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
=
[
0.512
,
0.512
],
bev_pos
=
None
,
prev_bev
=
None
,
**
kwargs
):
"""
obtain bev features.
"""
bs
=
mlvl_feats
[
0
].
size
(
0
)
bev_queries
=
bev_queries
.
unsqueeze
(
1
).
repeat
(
1
,
bs
,
1
)
bev_pos
=
bev_pos
.
flatten
(
2
).
permute
(
2
,
0
,
1
)
feat_flatten
=
[]
spatial_shapes
=
[]
for
lvl
,
feat
in
enumerate
(
mlvl_feats
):
bs
,
num_cam
,
c
,
h
,
w
=
feat
.
shape
spatial_shape
=
(
h
,
w
)
feat
=
feat
.
flatten
(
3
).
permute
(
1
,
0
,
3
,
2
)
if
self
.
use_cams_embeds
:
feat
=
feat
+
self
.
cams_embeds
[:,
None
,
None
,
:].
to
(
feat
.
dtype
)
feat
=
feat
+
self
.
level_embeds
[
None
,
None
,
lvl
:
lvl
+
1
,
:].
to
(
feat
.
dtype
)
spatial_shapes
.
append
(
spatial_shape
)
feat_flatten
.
append
(
feat
)
feat_flatten
=
torch
.
cat
(
feat_flatten
,
2
)
spatial_shapes
=
torch
.
as_tensor
(
spatial_shapes
,
dtype
=
torch
.
long
,
device
=
bev_pos
.
device
)
level_start_index
=
torch
.
cat
((
spatial_shapes
.
new_zeros
((
1
,)),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]))
feat_flatten
=
feat_flatten
.
permute
(
0
,
2
,
1
,
3
)
# (num_cam, H*W, bs, embed_dims)
bev_embed
=
self
.
encoder
(
bev_queries
,
feat_flatten
,
feat_flatten
,
bev_h
=
bev_h
,
bev_w
=
bev_w
,
bev_pos
=
bev_pos
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
prev_bev
=
None
,
shift
=
bev_queries
.
new_tensor
([
0
,
0
]).
unsqueeze
(
0
),
**
kwargs
)
# rotate current bev to final aligned
prev_bev
=
bev_embed
if
'aug_param'
in
kwargs
[
'img_metas'
][
0
]
and
'GlobalRotScaleTransImage_param'
in
kwargs
[
'img_metas'
][
0
][
'aug_param'
]:
rot_angle
,
scale_ratio
,
flip_dx
,
flip_dy
,
bda_mat
,
only_gt
=
kwargs
[
'img_metas'
][
0
][
'aug_param'
][
'GlobalRotScaleTransImage_param'
]
prev_bev
=
prev_bev
.
reshape
(
bs
,
bev_h
,
bev_w
,
-
1
).
permute
(
0
,
3
,
1
,
2
)
# bchw
if
only_gt
:
# rot angle
# prev_bev = torchvision.transforms.functional.rotate(prev_bev, -30, InterpolationMode.BILINEAR)
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
bev_h
-
0.5
,
bev_h
,
dtype
=
bev_queries
.
dtype
,
device
=
bev_queries
.
device
),
torch
.
linspace
(
0.5
,
bev_w
-
0.5
,
bev_w
,
dtype
=
bev_queries
.
dtype
,
device
=
bev_queries
.
device
))
ref_y
=
(
ref_y
/
bev_h
)
ref_x
=
(
ref_x
/
bev_w
)
grid
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
grid_shift
=
grid
*
2.0
-
1.0
grid_shift
=
grid_shift
.
unsqueeze
(
0
).
unsqueeze
(
-
1
)
# bda_mat = ( bda_mat[:2, :2] / scale_ratio).to(grid_shift).view(1, 1, 1, 2,2).repeat(grid_shift.shape[0], grid_shift.shape[1], grid_shift.shape[2], 1, 1)
bda_mat
=
bda_mat
[:
2
,
:
2
].
to
(
grid_shift
).
view
(
1
,
1
,
1
,
2
,
2
).
repeat
(
grid_shift
.
shape
[
0
],
grid_shift
.
shape
[
1
],
grid_shift
.
shape
[
2
],
1
,
1
)
grid_shift
=
torch
.
matmul
(
bda_mat
,
grid_shift
).
squeeze
(
-
1
)
# grid_shift = grid_shift / scale_ratio
prev_bev
=
torch
.
nn
.
functional
.
grid_sample
(
prev_bev
,
grid_shift
,
align_corners
=
False
)
# if flip_dx:
# prev_bev = torch.flip(prev_bev, dims=[-1])
# if flip_dy:
# prev_bev = torch.flip(prev_bev, dims=[-2])
prev_bev
=
prev_bev
.
reshape
(
bs
,
-
1
,
bev_h
*
bev_w
)
prev_bev
=
prev_bev
.
permute
(
0
,
2
,
1
)
return
prev_bev
@
TRANSFORMER
.
register_module
()
class
PerceptionTransformerV2
(
PerceptionTransformerBEVEncoder
):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def
__init__
(
self
,
num_feature_levels
=
4
,
num_cams
=
6
,
two_stage_num_proposals
=
300
,
encoder
=
None
,
embed_dims
=
256
,
use_cams_embeds
=
True
,
rotate_center
=
[
100
,
100
],
frames
=
(
0
,),
decoder
=
None
,
num_fusion
=
3
,
inter_channels
=
None
,
**
kwargs
):
super
(
PerceptionTransformerV2
,
self
).
__init__
(
num_feature_levels
,
num_cams
,
two_stage_num_proposals
,
encoder
,
embed_dims
,
use_cams_embeds
,
rotate_center
,
**
kwargs
)
self
.
decoder
=
build_transformer_layer_sequence
(
decoder
)
"""Initialize layers of the Detr3DTransformer."""
self
.
reference_points
=
nn
.
Linear
(
self
.
embed_dims
,
3
)
self
.
frames
=
frames
if
len
(
self
.
frames
)
>
1
:
self
.
fusion
=
ResNetFusion
(
len
(
self
.
frames
)
*
self
.
embed_dims
,
self
.
embed_dims
,
inter_channels
if
inter_channels
is
not
None
else
len
(
self
.
frames
)
*
self
.
embed_dims
,
num_fusion
)
def
init_weights
(
self
):
"""Initialize the transformer weights."""
super
().
init_weights
()
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
MSDeformableAttention3D
)
or
isinstance
(
m
,
TemporalSelfAttention
)
\
or
isinstance
(
m
,
CustomMSDeformableAttention
):
try
:
m
.
init_weight
()
except
AttributeError
:
m
.
init_weights
()
xavier_init
(
self
.
reference_points
,
distribution
=
'uniform'
,
bias
=
0.
)
def
get_bev_features
(
self
,
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
=
[
0.512
,
0.512
],
bev_pos
=
None
,
prev_bev
=
None
,
**
kwargs
):
return
super
().
forward
(
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
,
bev_pos
,
prev_bev
,
**
kwargs
)
def
forward
(
self
,
mlvl_feats
,
bev_queries
,
object_query_embed
,
bev_h
,
bev_w
,
grid_length
=
[
0.512
,
0.512
],
bev_pos
=
None
,
reg_branches
=
None
,
cls_branches
=
None
,
prev_bev
=
None
,
**
kwargs
):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, num_cams, embed_dims, h, w].
bev_queries (Tensor): (bev_h*bev_w, c)
bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
object_query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when `with_box_refine` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- bev_embed: BEV features
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape
\
(num_dec_layers, bs, num_query, embed_dims), else has
\
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference
\
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference
\
points in decoder, has shape
\
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of
\
proposals generated from
\
encoder's feature maps, has shape
\
(batch, h*w, num_classes).
\
Only would be returned when `as_two_stage` is True,
\
otherwise None.
- enc_outputs_coord_unact: The regression results
\
generated from encoder's feature maps., has shape
\
(batch, h*w, 4). Only would
\
be returned when `as_two_stage` is True,
\
otherwise None.
"""
bev_embed
=
self
.
get_bev_features
(
mlvl_feats
,
bev_queries
,
bev_h
,
bev_w
,
grid_length
=
grid_length
,
bev_pos
=
bev_pos
,
prev_bev
=
None
,
**
kwargs
)
# bev_embed shape: bs, bev_h*bev_w, embed_dims
if
len
(
self
.
frames
)
>
1
:
cur_ind
=
list
(
self
.
frames
).
index
(
0
)
assert
prev_bev
[
cur_ind
]
is
None
and
len
(
prev_bev
)
==
len
(
self
.
frames
)
prev_bev
[
cur_ind
]
=
bev_embed
# fill prev frame feature
for
i
in
range
(
1
,
cur_ind
+
1
):
if
prev_bev
[
cur_ind
-
i
]
is
None
:
prev_bev
[
cur_ind
-
i
]
=
prev_bev
[
cur_ind
-
i
+
1
].
detach
()
# fill next frame feature
for
i
in
range
(
cur_ind
+
1
,
len
(
self
.
frames
)):
if
prev_bev
[
i
]
is
None
:
prev_bev
[
i
]
=
prev_bev
[
i
-
1
].
detach
()
bev_embed
=
[
x
.
reshape
(
x
.
shape
[
0
],
bev_h
,
bev_w
,
x
.
shape
[
-
1
]).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
for
x
in
prev_bev
]
bev_embed
=
self
.
fusion
(
bev_embed
)
bs
=
mlvl_feats
[
0
].
size
(
0
)
query_pos
,
query
=
torch
.
split
(
object_query_embed
,
self
.
embed_dims
,
dim
=
1
)
query_pos
=
query_pos
.
unsqueeze
(
0
).
expand
(
bs
,
-
1
,
-
1
)
query
=
query
.
unsqueeze
(
0
).
expand
(
bs
,
-
1
,
-
1
)
reference_points
=
self
.
reference_points
(
query_pos
)
reference_points
=
reference_points
.
sigmoid
()
init_reference_out
=
reference_points
query
=
query
.
permute
(
1
,
0
,
2
)
query_pos
=
query_pos
.
permute
(
1
,
0
,
2
)
bev_embed
=
bev_embed
.
permute
(
1
,
0
,
2
)
inter_states
,
inter_references
=
self
.
decoder
(
query
=
query
,
key
=
None
,
value
=
bev_embed
,
query_pos
=
query_pos
,
reference_points
=
reference_points
,
reg_branches
=
reg_branches
,
cls_branches
=
cls_branches
,
spatial_shapes
=
torch
.
tensor
([[
bev_h
,
bev_w
]],
device
=
query
.
device
),
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
**
kwargs
)
inter_references_out
=
inter_references
return
bev_embed
,
inter_states
,
init_reference_out
,
inter_references_out
projects/mmdet3d_plugin/bevformer/runner/__init__.py
0 → 100644
View file @
4cd43886
from
.epoch_based_runner
import
EpochBasedRunner_video
\ No newline at end of file
Prev
1
2
3
4
5
6
7
8
9
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment