OpenDAS / dcnv3 · Commits · b64d9ca3

Commit b64d9ca3 (Unverified)
Merge pull request #105 from zhiqi-li/occupancy

    support occupancy prediction

Authored Apr 17, 2023 by Wenhai Wang; committed by GitHub, Apr 17, 2023.
Parents: bdd98bcb, df3c64a9

Showing 20 changed files with 3400 additions and 0 deletions (+3400 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/src/vision.cpp (+17 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/test.py (+263 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py (+1 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_occ_head.py (+207 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/__init__.py (+1 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/bevformer_occ.py (+298 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/__init__.py (+1 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py (+14 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/__init__.py (+6 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/custom_base_transformer_layer.py (+262 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/decoder.py (+345 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/encoder.py (+408 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/multi_scale_deformable_attn_function.py (+163 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py (+400 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py (+272 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer.py (+289 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer_occ.py (+352 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/__init__.py (+1 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py (+97 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/core/bbox/assigners/__init__.py (+3 −0)
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/src/vision.cpp  0 → 100644
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "dcnv3.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
    m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
}
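Note: this binding only declares the two C++/CUDA entry points; it still has to be compiled into an importable extension. Below is a minimal build sketch using PyTorch's standard extension tooling. It is an assumption for illustration, not the repo's actual setup.py; the module name, source globs, and flags are placeholders.

# Hypothetical setup.py sketch for compiling the DCNv3 op with PyTorch's
# build helpers; names and flags are illustrative only.
import glob
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = glob.glob(os.path.join('src', '*.cpp')) + \
    glob.glob(os.path.join('src', 'cuda', '*.cu'))  # assumed layout

setup(
    name='DCNv3',
    ext_modules=[
        CUDAExtension(
            name='DCNv3',  # importable module name (assumed)
            sources=sources,
            extra_compile_args={'cxx': [], 'nvcc': ['-O2']},
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)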
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/test.py  0 → 100644
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck

from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
P = Kh * Kw
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input.double(), offset.double(), mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input.double(), offset.double(), mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale,
        im2col_step).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()

    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale,
        im2col_step).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()

    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_double(channels=4,
                                             grad_input=True,
                                             grad_offset=True,
                                             grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0.double(), offset0.double(), mask0.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1.double(), offset1.double(), mask1.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale,
        im2col_step)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4,
                                            grad_input=True,
                                            grad_offset=True,
                                            grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0, offset0, mask0,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1, offset1, mask1,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale,
        im2col_step)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_time_cost(im2col_step=128):
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step}; '
          f'input {input.shape}; points {P}')
    repeat = 100
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(
            input, offset, mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0,
            im2col_step)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(
            input, offset, mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0,
            im2col_step)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    for i in range(3):
        im2col_step = 128 * (2 ** i)
        check_time_cost(im2col_step)
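Note: the H_out/W_out expressions at the top of this script are the standard convolution output-size formula. As a standalone sanity check (not part of the original file, plain PyTorch on CPU is enough), the same number falls out of nn.Conv2d:

# Standalone check that the output-size arithmetic above matches nn.Conv2d.
import torch
import torch.nn as nn

H_in, Kh, pad, dilation, stride = 8, 3, 1, 1, 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1  # -> 8

conv = nn.Conv2d(1, 1, kernel_size=Kh, padding=pad,
                 dilation=dilation, stride=stride)
out = conv(torch.zeros(1, 1, H_in, H_in))
assert out.shape[-1] == H_out == 8  # formula agrees with Conv2d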
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py  0 → 100644
from .bevformer_occ_head import BEVFormerOccHead
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_occ_head.py  0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.utils import TORCH_VERSION, digit_version

from mmdet.core import multi_apply, reduce_mean
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import HEADS
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32, auto_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time
import numpy as np
import mmcv
import cv2 as cv
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmdet.models.utils import build_transformer
from mmdet.models.builder import build_loss
from mmcv.runner import BaseModule, force_fp32


@HEADS.register_module()
class BEVFormerOccHead(BaseModule):
    """Head of Detr3D.
    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool): Whether to generate the proposal from
            the outputs of the encoder.
        transformer (obj:`ConfigDict`): ConfigDict used for building
            the Encoder and Decoder.
        bev_h, bev_w (int): spatial shape of BEV queries.
    """

    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 bbox_coder=None,
                 num_cls_fcs=2,
                 code_weights=None,
                 pc_range=[-40, -40, -1.0, 40, 40, 5.4],
                 bev_h=30,
                 bev_w=30,
                 loss_occ=None,
                 use_mask=False,
                 positional_encoding=None,
                 **kwargs):
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.fp16_enabled = False
        self.num_classes = kwargs['num_classes']
        self.use_mask = use_mask

        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            transformer['as_two_stage'] = self.as_two_stage
        self.pc_range = pc_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]
        self.num_cls_fcs = num_cls_fcs - 1
        super(BEVFormerOccHead, self).__init__()

        self.loss_occ = build_loss(loss_occ)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims
        if not self.as_two_stage:
            self.bev_embedding = nn.Embedding(
                self.bev_h * self.bev_w, self.embed_dims)

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        # if self.loss_cls.use_sigmoid:
        #     bias_init = bias_init_with_prob(0.01)
        #     for m in self.cls_branches:
        #         nn.init.constant_(m[-1].bias, bias_init)

    @auto_fp16(apply_to=('mlvl_feats'))
    def forward(self, mlvl_feats, img_metas, prev_bev=None,
                only_bev=False, test=False):
        """Forward function.
        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 5D-tensor with shape
                (B, N, C, H, W).
            prev_bev: previous BEV features.
            only_bev: only compute BEV features with the encoder.
        Returns:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
                Shape [nb_dec, bs, num_query, 9].
        """
        bs, num_cam, _, _, _ = mlvl_feats[0].shape
        dtype = mlvl_feats[0].dtype
        object_query_embeds = None
        bev_queries = self.bev_embedding.weight.to(dtype)

        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)

        if only_bev:
            # only use the encoder to obtain BEV features,
            # TODO: refine the workaround
            return self.transformer.get_bev_features(
                mlvl_feats,
                bev_queries,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                img_metas=img_metas,
                prev_bev=prev_bev,
            )
        else:
            outputs = self.transformer(
                mlvl_feats,
                bev_queries,
                object_query_embeds,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                reg_branches=None,  # noqa:E501
                cls_branches=None,
                img_metas=img_metas,
                prev_bev=prev_bev)

        bev_embed, occ_outs = outputs

        outs = {
            'bev_embed': bev_embed,
            'occ': occ_outs,
        }

        return outs

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self,
             # gt_bboxes_list,
             # gt_labels_list,
             voxel_semantics,
             mask_camera,
             preds_dicts,
             gt_bboxes_ignore=None,
             img_metas=None):
        loss_dict = dict()
        occ = preds_dicts['occ']
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        losses = self.loss_single(voxel_semantics, mask_camera, occ)
        loss_dict['loss_occ'] = losses
        return loss_dict

    def loss_single(self, voxel_semantics, mask_camera, preds):
        voxel_semantics = voxel_semantics.long()
        if self.use_mask:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = preds.reshape(-1, self.num_classes)
            mask_camera = mask_camera.reshape(-1)
            num_total_samples = mask_camera.sum()
            loss_occ = self.loss_occ(preds, voxel_semantics, mask_camera,
                                     avg_factor=num_total_samples)
        else:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = preds.reshape(-1, self.num_classes)
            loss_occ = self.loss_occ(preds, voxel_semantics,)
        return loss_occ

    @force_fp32(apply_to=('preds'))
    def get_occ(self, preds_dicts, img_metas, rescale=False):
        """Generate occupancy predictions from head outputs.
        Args:
            preds_dicts: occ results.
            img_metas (list[dict]): Point cloud and image meta info.
        Returns:
            Tensor: per-voxel semantic class indices.
        """
        # return self.transformer.get_occ(
        #     preds_dicts, img_metas, rescale=rescale)
        # print(img_metas[0].keys())
        occ_out = preds_dicts['occ']
        occ_score = occ_out.softmax(-1)
        occ_score = occ_score.argmax(-1)

        return occ_score
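Note: BEVFormerOccHead is registered in mmdet's HEADS registry, so it is normally instantiated from a config dict rather than constructed directly. The fragment below is a hypothetical sketch of such a config; every value (class count, BEV size, loss and positional-encoding settings) is a placeholder assumption, not taken from this commit.

# Hypothetical config fragment (placeholder values); the registry turns
# this dict into a BEVFormerOccHead instance during model build.
pts_bbox_head = dict(
    type='BEVFormerOccHead',
    num_classes=18,            # e.g. 17 semantic classes + free (assumed)
    use_mask=True,             # weight the loss with the camera mask
    bev_h=200, bev_w=200,
    pc_range=[-40, -40, -1.0, 40, 40, 5.4],
    loss_occ=dict(type='CrossEntropyLoss', use_sigmoid=False,
                  loss_weight=1.0),
    transformer=dict(type='TransformerOcc'),  # full cfg omitted, see modules/
    positional_encoding=dict(type='LearnedPositionalEncoding',
                             num_feats=128, row_num_embed=200,
                             col_num_embed=200),
)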
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/__init__.py  0 → 100644
from .bevformer_occ import BEVFormerOcc
\ No newline at end of file
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/bevformer_occ.py  0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
import time
import copy
import numpy as np
import mmdet3d
from projects.mmdet3d_plugin.models.utils.bricks import run_time


@DETECTORS.register_module()
class BEVFormerOcc(MVXTwoStageDetector):
    """BEVFormer.
    Args:
        video_test_mode (bool): Decide whether to use temporal information during inference.
    """

    def __init__(self,
                 use_grid_mask=False,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False):

        super(BEVFormerOcc, self).__init__(
            pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
            pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
            pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg,
            pretrained)
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # temporal
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

    def extract_img_feat(self, img, img_metas, len_queue=None):
        """Extract features of images."""
        B = img.size(0)
        if img is not None:
            # input_shape = img.shape[-2:]
            # # update real input shape of each single img
            # for img_meta in img_metas:
            #     img_meta.update(input_shape=input_shape)

            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                img = img.reshape(B * N, C, H, W)
            if self.use_grid_mask:
                img = self.grid_mask(img)

            img_feats = self.img_backbone(img)
            if isinstance(img_feats, dict):
                img_feats = list(img_feats.values())
        else:
            return None
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            if len_queue is not None:
                img_feats_reshaped.append(
                    img_feat.view(int(B / len_queue), len_queue,
                                  int(BN / B), C, H, W))
            else:
                img_feats_reshaped.append(
                    img_feat.view(B, int(BN / B), C, H, W))
        return img_feats_reshaped

    @auto_fp16(apply_to=('img'))
    def extract_feat(self, img, img_metas=None, len_queue=None):
        """Extract features from images and points."""
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
        return img_feats

    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          voxel_semantics,
                          mask_camera,
                          img_metas,
                          gt_bboxes_ignore=None,
                          prev_bev=None):
        """Forward function.
        Args:
            pts_feats (list[torch.Tensor]): Features of point cloud branch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            prev_bev (torch.Tensor, optional): BEV features of previous frame.
        Returns:
            dict: Losses of each branch.
        """

        outs = self.pts_bbox_head(pts_feats, img_metas, prev_bev)
        loss_inputs = [voxel_semantics, mask_camera, outs]
        losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
        return losses

    def forward_dummy(self, img):
        dummy_metas = None
        return self.forward_test(img=img, img_metas=[[dummy_metas]])

    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.
        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def obtain_history_bev(self, imgs_queue, img_metas_list):
        """Obtain history BEV features iteratively. To save GPU memory,
        gradients are not calculated.
        """
        self.eval()

        with torch.no_grad():
            prev_bev = None
            bs, len_queue, num_cams, C, H, W = imgs_queue.shape
            imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
            img_feats_list = self.extract_feat(
                img=imgs_queue, len_queue=len_queue)
            for i in range(len_queue):
                img_metas = [each[i] for each in img_metas_list]
                if not img_metas[0]['prev_bev_exists']:
                    prev_bev = None
                # img_feats = self.extract_feat(img=img, img_metas=img_metas)
                img_feats = [each_scale[:, i] for each_scale in img_feats_list]
                prev_bev = self.pts_bbox_head(
                    img_feats, img_metas, prev_bev, only_bev=True)
            self.train()
            return prev_bev

    @auto_fp16(apply_to=('img', 'points'))
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      voxel_semantics=None,
                      mask_lidar=None,
                      mask_camera=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      img_depth=None,
                      img_mask=None,
                      ):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """

        len_queue = img.size(1)
        prev_img = img[:, :-1, ...]
        img = img[:, -1, ...]

        prev_img_metas = copy.deepcopy(img_metas)
        prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)

        img_metas = [each[len_queue - 1] for each in img_metas]
        if not img_metas[0]['prev_bev_exists']:
            prev_bev = None
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, voxel_semantics,
                                            mask_camera, img_metas,
                                            gt_bboxes_ignore, prev_bev)

        losses.update(losses_pts)
        return losses

    def forward_test(self, img_metas, img=None, voxel_semantics=None,
                     mask_lidar=None, mask_camera=None, **kwargs):
        for var, name in [(img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))
        img = [img] if img is None else img

        if img_metas[0][0]['scene_token'] != self.prev_frame_info['scene_token']:
            # the first sample of each scene is truncated
            self.prev_frame_info['prev_bev'] = None
        # update idx
        self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']

        # do not use temporal information
        if not self.video_test_mode:
            self.prev_frame_info['prev_bev'] = None

        # Get the delta of ego position and angle between two timestamps.
        tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
        tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
        if self.prev_frame_info['prev_bev'] is not None:
            img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
            img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
        else:
            img_metas[0][0]['can_bus'][-1] = 0
            img_metas[0][0]['can_bus'][:3] = 0

        new_prev_bev, occ_results = self.simple_test(
            img_metas[0], img[0],
            prev_bev=self.prev_frame_info['prev_bev'], **kwargs)
        # During inference, we save the BEV features and ego motion of each timestamp.
        self.prev_frame_info['prev_pos'] = tmp_pos
        self.prev_frame_info['prev_angle'] = tmp_angle
        self.prev_frame_info['prev_bev'] = new_prev_bev
        return occ_results

    def simple_test_pts(self, x, img_metas, prev_bev=None, rescale=False):
        """Test function."""
        outs = self.pts_bbox_head(x, img_metas, prev_bev=prev_bev, test=True)

        occ = self.pts_bbox_head.get_occ(
            outs, img_metas, rescale=rescale)

        return outs['bev_embed'], occ

    def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
        """Test function without augmentation."""
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        # bbox_list = [dict() for i in range(len(img_metas))]
        new_prev_bev, occ = self.simple_test_pts(
            img_feats, img_metas, prev_bev, rescale=rescale)
        # for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
        #     result_dict['pts_bbox'] = pts_bbox
        return new_prev_bev, occ
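Note: the can_bus bookkeeping in forward_test converts absolute ego pose into per-frame deltas so that temporal attention sees relative motion, and resets the deltas to zero at the start of each scene. A toy standalone trace of that logic (the poses and the 18-element CAN bus layout are made-up illustration, not data from this repo):

# Toy illustration of the can_bus delta logic in forward_test.
import numpy as np

prev_pos, prev_angle = None, None
for abs_pos, abs_angle in [(np.array([0.0, 0.0, 0.0]), 10.0),
                           (np.array([1.5, 0.2, 0.0]), 12.0)]:
    can_bus = np.zeros(18)          # nuScenes-style CAN bus vector (assumed)
    can_bus[:3] = abs_pos
    can_bus[-1] = abs_angle
    if prev_pos is None:            # first frame of a scene: zero deltas
        can_bus[:3] = 0
        can_bus[-1] = 0
    else:                           # later frames: deltas w.r.t. prev frame
        can_bus[:3] -= prev_pos
        can_bus[-1] -= prev_angle
    prev_pos, prev_angle = abs_pos, abs_angle
    print(can_bus[:3], can_bus[-1])  # [0 0 0] 0.0, then [1.5 0.2 0] 2.0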
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/__init__.py  0 → 100644
from .custom_hooks import TransferWeight
\ No newline at end of file
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py  0 → 100644
from mmcv.runner.hooks.hook import HOOKS, Hook
from projects.mmdet3d_plugin.models.utils import run_time


@HOOKS.register_module()
class TransferWeight(Hook):

    def __init__(self, every_n_inters=1):
        self.every_n_inters = every_n_inters

    def after_train_iter(self, runner):
        if self.every_n_inner_iters(runner, self.every_n_inters):
            runner.eval_model.load_state_dict(runner.model.state_dict())
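Note: because TransferWeight is registered in mmcv's HOOKS registry, it can be enabled from a config file in the usual mmcv way, as in the assumed fragment below. The hook copies the training model's weights into `runner.eval_model`, an attribute a stock runner does not have; presumably it is supplied by the custom epoch_based_runner.py also added in this commit.

# Assumed config usage: mmcv builds hooks listed under `custom_hooks`
# from the registry; this syncs `runner.eval_model` every iteration.
custom_hooks = [
    dict(type='TransferWeight', every_n_inters=1)
]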
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/__init__.py  0 → 100644
from .transformer import PerceptionTransformer
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import DetectionTransformerDecoder
from .transformer_occ import TransformerOcc
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/custom_base_transformer_layer.py  0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import copy
import warnings

import torch
import torch.nn as nn

from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                      POSITIONAL_ENCODING, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
    from mmcv.ops.multi_scale_deform_attn import \
        MultiScaleDeformableAttention  # noqa F401
    warnings.warn(
        ImportWarning(
            '``MultiScaleDeformableAttention`` has been moved to '
            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
            '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
            'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
        ))
except ImportError:
    warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
                  '``mmcv.ops.multi_scale_deform_attn``, '
                  'You should install ``mmcv-full`` if you need this module. ')
from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention


@TRANSFORMER_LAYER.register_module()
class MyCustomBaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.
    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN or LN` and
    using different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with the corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when you specify the first element as `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + \
            operation_order.count('cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs['embed_dims'] = self.embed_dims
            else:
                # print()
                # print('ffn_cfgs ', ffn_cfgs[ffn_index]['embed_dims'], self.embed_dims)
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index]))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.
        **kwargs contains some specific arguments of attentions.
        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.
        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
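Note: a minimal sketch of building this layer directly, assuming mmcv-full is installed and using standard mmcv attention/FFN configs; the dimensions and the operation order are illustrative, not values used in this commit.

# Minimal construction sketch (illustrative values only).
import torch
from projects.mmdet3d_plugin.bevformer.modules.custom_base_transformer_layer import \
    MyCustomBaseTransformerLayer

layer = MyCustomBaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=512),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))  # self-attn block

x = torch.rand(2, 100, 256)   # (bs, num_query, embed_dims), batch_first=True
out = layer(x)                # same shape as the input
assert out.shape == x.shape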
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/decoder.py  0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import mmcv
import cv2 as cv
import copy
import warnings
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32, \
    MultiScaleDeformableAttnFunction_fp16

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.
    Args:
        x (Tensor): The tensor to do the inverse.
        eps (float): EPS to avoid numerical overflow. Defaults to 1e-5.
    Returns:
        Tensor: The result of applying the inverse function of sigmoid,
            with the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetectionTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR3D transformer.
    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of the last normalization layer.
            Default: `LN`.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(DetectionTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.fp16_enabled = False

    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `Detr3DTransformerDecoder`.
        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points of offset,
                with shape (bs, num_query, 4) when as_two_stage,
                otherwise with shape (bs, num_query, 2).
            reg_branches (obj:`nn.ModuleList`): Used for
                refining the regression results. Only passed
                when with_box_refine is True, otherwise `None`.
        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):

            reference_points_input = reference_points[..., :2].unsqueeze(
                2)  # BS NUM_QUERY NUM_LEVEL 2
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                key_padding_mask=key_padding_mask,
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)

                assert reference_points.shape[-1] == 3

                new_reference_points = torch.zeros_like(reference_points)
                new_reference_points[..., :2] = tmp[
                    ..., :2] + inverse_sigmoid(reference_points[..., :2])
                new_reference_points[..., 2:3] = tmp[
                    ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

                new_reference_points = new_reference_points.sigmoid()

                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points


@ATTENTION.register_module()
class CustomMSDeformableAttention(BaseModule):
    """An attention module used in Deformable-Detr.
    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.
    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 64.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.
        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area.
                Or (N, Length_{query}, num_levels, 4), where the two
                additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """

        if value is None:
            value = query

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels,
            self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = \
                    MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = \
                    MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
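Note: inverse_sigmoid at the top of this file is the logit function with clamping for numerical safety; the iterative box refinement above relies on it undoing the sigmoid. A quick standalone round-trip check (the function body is copied here so the snippet runs without the plugin installed):

# Standalone check: inverse_sigmoid undoes torch.sigmoid away from the
# clamped boundaries.
import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

x = torch.tensor([-2.0, 0.0, 3.0])
assert torch.allclose(inverse_sigmoid(torch.sigmoid(x)), x, atol=1e-4)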
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/encoder.py  0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
from
projects.mmdet3d_plugin.models.utils.visual
import
save_tensor
from
.custom_base_transformer_layer
import
MyCustomBaseTransformerLayer
import
copy
import
warnings
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
from
mmcv.runner
import
force_fp32
,
auto_fp16
import
numpy
as
np
import
torch
import
cv2
as
cv
import
mmcv
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
BEVFormerEncoder
(
TransformerLayerSequence
):
"""
Attention with both self and cross
Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def
__init__
(
self
,
*
args
,
pc_range
=
None
,
num_points_in_pillar
=
4
,
return_intermediate
=
False
,
dataset_type
=
'nuscenes'
,
**
kwargs
):
super
(
BEVFormerEncoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
return_intermediate
=
return_intermediate
self
.
num_points_in_pillar
=
num_points_in_pillar
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
@
staticmethod
def
get_reference_points
(
H
,
W
,
Z
=
8
,
num_points_in_pillar
=
4
,
dim
=
'3d'
,
bs
=
1
,
device
=
'cuda'
,
dtype
=
torch
.
float
):
"""Get the reference points used in SCA and TSA.
Args:
H, W: spatial shape of bev.
Z: hight of pillar.
D: sample D points uniformly from each pillar.
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
# reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                                device=device).view(-1, 1, 1).expand(
                num_points_in_pillar, H, W) / Z
            xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                                device=device).view(1, 1, W).expand(
                num_points_in_pillar, H, W) / W
            ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                                device=device).view(1, H, 1).expand(
                num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            # shape: (bs, num_points_in_pillar, h*w, 3)
            return ref_3d

        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device)
            )
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d
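As a quick sanity check on the shapes above, the 3D branch yields one normalized (x, y, z) triple per pillar sample per BEV cell, while the 2D branch yields one normalized (x, y) per cell. A standalone sketch (not part of the commit; toy sizes, and it assumes this module is importable so the staticmethod can be called on the class):

import torch

# Toy sizes, assuming the BEVFormerEncoder defaults above.
H, W, Z, D, bs = 50, 50, 8, 4, 2
ref_3d = BEVFormerEncoder.get_reference_points(
    H, W, Z=Z, num_points_in_pillar=D, dim='3d', bs=bs, device='cpu')
ref_2d = BEVFormerEncoder.get_reference_points(
    H, W, dim='2d', bs=bs, device='cpu')
assert ref_3d.shape == (bs, D, H * W, 3)   # (bs, num_points_in_pillar, h*w, 3)
assert ref_2d.shape == (bs, H * W, 1, 2)   # (bs, h*w, 1, 2)
# all coordinates are normalized into (0, 1)
assert ref_3d.min() > 0 and ref_3d.max() < 1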
# This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        ego2lidar = img_metas[0]['ego2lidar']
        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        ego2lidar = reference_points.new_tensor(ego2lidar)
        # ego2lidar = ego2lidar.unsqueeze(dim=0).repeat(num_imgs,1,1).unsqueeze(0)
        reference_points = reference_points.clone()

        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)
        reference_points = reference_points.permute(1, 0, 2, 3)
        # shape: (num_points_in_pillar, bs, h*w, 4)
        D, B, num_query = reference_points.size()[:3]
        # D = num_points_in_pillar, num_query = h*w
        num_cam = lidar2img.size(1)

        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
        # shape: (num_points_in_pillar, bs, num_cam, h*w, 4)
        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
        ego2lidar = ego2lidar.view(
            1, 1, 1, 1, 4, 4).repeat(D, 1, num_cam, num_query, 1, 1)
        reference_points_cam = torch.matmul(
            torch.matmul(lidar2img.to(torch.float32),
                         ego2lidar.to(torch.float32)),
            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5

        bev_mask = (reference_points_cam[..., 2:3] > eps)
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3],
            torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        # shape: (num_cam, bs, h*w, num_points_in_pillar, 2)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        return reference_points_cam, bev_mask
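The projection chain above is plain homogeneous matrix math: a normalized ego-frame point is rescaled into metric space via `pc_range`, lifted to homogeneous coordinates, mapped through ego2lidar and then lidar2img, and perspective-divided with the `eps` clamp. A minimal standalone sketch of the same chain for a single point (the 4x4 matrices here are placeholder identities, not real calibrations from `img_metas`):

import torch

# Placeholder calibration: identity transforms instead of real
# ego2lidar / lidar2img matrices.
ego2lidar = torch.eye(4)
lidar2img = torch.eye(4)
pc_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]

# One normalized reference point (x, y, z), each coordinate in (0, 1).
p = torch.tensor([0.5, 0.5, 0.5])
p_metric = torch.tensor([
    p[0] * (pc_range[3] - pc_range[0]) + pc_range[0],
    p[1] * (pc_range[4] - pc_range[1]) + pc_range[1],
    p[2] * (pc_range[5] - pc_range[2]) + pc_range[2],
    1.0,  # homogeneous coordinate
])
cam = lidar2img @ ego2lidar @ p_metric
eps = 1e-5
visible = cam[2] > eps                    # point lies in front of the camera
uv = cam[:2] / cam[2].clamp(min=eps)      # perspective divide, clamped
print(visible, uv)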
    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
"""Forward function for `TransformerDecoder`.
Args:
bev_query (Tensor): Input BEV query with shape
`(num_query, bs, embed_dims)`.
key & value (Tensor): Input multi-cameta features with shape
(num_cam, num_value, bs, embed_dims)
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
valid_ratios (Tensor): The radios of valid
points on the feature map, has shape
(bs, num_levels, 2)
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
        output = bev_query
        intermediate = []

        ref_3d = self.get_reference_points(
            bev_h, bev_w, self.pc_range[5] - self.pc_range[2],
            self.num_points_in_pillar, dim='3d', bs=bev_query.size(1),
            device=bev_query.device, dtype=bev_query.dtype)
        ref_2d = self.get_reference_points(
            bev_h, bev_w, dim='2d', bs=bev_query.size(1),
            device=bev_query.device, dtype=bev_query.dtype)

        reference_points_cam, bev_mask = self.point_sampling(
            ref_3d, self.pc_range, kwargs['img_metas'])
# bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
        shift_ref_2d = ref_2d  # .clone()
        shift_ref_2d += shift[:, None, None, :]
# (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
        bev_query = bev_query.permute(1, 0, 2)
        bev_pos = bev_pos.permute(1, 0, 2)
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        if prev_bev is not None:
            prev_bev = prev_bev.permute(1, 0, 2)
            prev_bev = torch.stack(
                [prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
            hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
                bs * 2, len_bev, num_bev_level, 2)
        else:
            hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
                bs * 2, len_bev, num_bev_level, 2)

        for lid, layer in enumerate(self.layers):
            output = layer(
                bev_query,
                key,
                value,
                *args,
                bev_pos=bev_pos,
                ref_2d=hybird_ref_2d,
                ref_3d=ref_3d,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cam=reference_points_cam,
                bev_mask=bev_mask,
                prev_bev=prev_bev,
                **kwargs)

            bev_query = output
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output
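The stacking above interleaves the (shifted) previous BEV and the current queries along the batch axis so that temporal self-attention can process both in one pass: in the doubled batch, slot 2*b holds history for sample b and slot 2*b+1 holds its current queries. A small sketch of that interleaving (toy tensors standing in for the real BEV features):

import torch

bs, len_bev, dims = 2, 6, 4
prev_bev = torch.zeros(bs, len_bev, dims)   # stand-in for history BEV
bev_query = torch.ones(bs, len_bev, dims)   # stand-in for current queries

stacked = torch.stack([prev_bev, bev_query], 1).reshape(bs * 2, len_bev, dims)
# even batch slots hold history, odd slots hold the current queries
assert torch.equal(stacked[0], prev_bev[0])
assert torch.equal(stacked[1], bev_query[0])
assert torch.equal(stacked[2], prev_bev[1])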
@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(MyCustomBaseTransformerLayer):
"""Implements decoder layer in DETR transformer.
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
Configs for self_attention or cross_attention, the order
should be consistent with it in `operation_order`. If it is
            a dict, it would be expanded to the number of attention in
`operation_order`.
feedforward_channels (int): The hidden dimension for FFNs.
ffn_dropout (float): Probability of an element to be zeroed
in ffn. Default 0.0.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Default:None
        act_cfg (dict): The activation config for FFNs. Default: `ReLU`.
norm_cfg (dict): Config dict for normalization layer.
Default: `LN`.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default:2.
"""
    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        self.fp16_enabled = False
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
                [bs, num_queries, embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__}')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            # temporal self-attention
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor(
                        [[bev_h, bev_w]], device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            # spatial cross-attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
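Because the layer asserts `len(operation_order) == 6`, a typical config pairs one temporal self-attention with one spatial cross-attention, each followed by a norm, then the FFN. A hypothetical mmcv-style config sketch (the numeric values are illustrative, not taken from this commit's configs):

# Hypothetical config sketch for one BEVFormerLayer; field values are
# illustrative, not the ones shipped in this repository.
bevformer_layer = dict(
    type='BEVFormerLayer',
    attn_cfgs=[
        dict(type='TemporalSelfAttention', embed_dims=256, num_levels=1),
        dict(type='SpatialCrossAttention',
             embed_dims=256,
             deformable_attention=dict(
                 type='MSDeformableAttention3D', embed_dims=256, num_levels=4)),
    ],
    feedforward_channels=512,
    ffn_dropout=0.1,
    operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))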
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/multi_scale_deformable_attn_function.py
0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
class MultiScaleDeformableAttnFunction_fp16(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
class MultiScaleDeformableAttnFunction_fp32(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
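When CUDA (and hence the `_ext` kernels) is unavailable, the callers of these Functions fall back to mmcv's pure-PyTorch reference implementation. A CPU-only sketch with random tensors shaped per the docstring above (single feature level, toy sizes):

import torch
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch

bs, num_queries, num_heads, head_dims = 1, 4, 2, 8
spatial_shapes = torch.tensor([[6, 6]])          # one level of 6x6 features
num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

value = torch.rand(bs, num_value, num_heads, head_dims)
sampling_locations = torch.rand(bs, num_queries, num_heads, 1, 4, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, 1, 4).softmax(-1)

out = multi_scale_deformable_attn_pytorch(
    value, spatial_shapes, sampling_locations, attention_weights)
assert out.shape == (bs, num_queries, num_heads * head_dims)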
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import build_attention
import math
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
    MultiScaleDeformableAttnFunction_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@ATTENTION.register_module()
class SpatialCrossAttention(BaseModule):
"""An attention module used in BEVFormer.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_cams (int): The number of cameras
dropout (float): A Dropout layer on `inp_residual`.
            Default: 0.1.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
deformable_attention: (dict): The config for the deformable attention used in SCA.
"""
    def __init__(self,
                 embed_dims=256,
                 num_cams=6,
                 pc_range=None,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 deformable_attention=dict(
                     type='MSDeformableAttention3D',
                     embed_dims=256,
                     num_levels=4),
                 **kwargs):
        super(SpatialCrossAttention, self).__init__(init_cfg)
        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.deformable_attention = build_attention(deformable_attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.batch_first = batch_first
        self.init_weight()
    def init_weight(self):
        """Default initialization for Parameters of Module."""
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos',
                          'reference_points_cam'))
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`. (B, N, C, H, W)
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, 4),
                all elements are in range [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
            slots = torch.zeros_like(query)
        if query_pos is not None:
            query = query + query_pos

        bs, num_query, _ = query.size()
        # bevformer reference_points_cam shape:
        # (num_cam, bs, h*w, num_points_in_pillar, 2)
        D = reference_points_cam.size(3)
        indexes = []
        for i, mask_per_img in enumerate(bev_mask):
            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            indexes.append(index_query_per_img)
        max_len = max([len(each) for each in indexes])

        # each camera only interacts with its corresponding BEV queries.
        # This step can greatly save GPU memory.
        queries_rebatch = query.new_zeros(
            [bs, self.num_cams, max_len, self.embed_dims])
        reference_points_rebatch = reference_points_cam.new_zeros(
            [bs, self.num_cams, max_len, D, 2])

        for j in range(bs):
            for i, reference_points_per_img in enumerate(reference_points_cam):
                index_query_per_img = indexes[i]
                queries_rebatch[j, i, :len(index_query_per_img)] = \
                    query[j, index_query_per_img]
                reference_points_rebatch[j, i, :len(index_query_per_img)] = \
                    reference_points_per_img[j, index_query_per_img]

        num_cams, l, bs, embed_dims = key.shape

        key = key.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)

        queries = self.deformable_attention(
            query=queries_rebatch.view(bs * self.num_cams, max_len,
                                       self.embed_dims),
            key=key,
            value=value,
            reference_points=reference_points_rebatch.view(
                bs * self.num_cams, max_len, D, 2),
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index).view(
            bs, self.num_cams, max_len, self.embed_dims)

        for j in range(bs):
            for i, index_query_per_img in enumerate(indexes):
                slots[j, index_query_per_img] += \
                    queries[j, i, :len(index_query_per_img)]

        count = bev_mask.sum(-1) > 0
        count = count.permute(1, 2, 0).sum(-1)
        count = torch.clamp(count, min=1.0)
        slots = slots / count[..., None]
        slots = self.output_proj(slots)

        return self.dropout(slots) + inp_residual
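The rebatching trick above rests on boolean-mask indexing: per camera, only queries whose BEV cells project into that image are gathered into a dense `max_len` buffer, and the attended results are scattered back and averaged over the cameras that saw each cell. A minimal sketch of the gather/scatter on toy tensors (bs = 1, made-up visibility):

import torch

num_cams, num_query, dims = 2, 5, 3
query = torch.arange(num_query * dims, dtype=torch.float).view(1, num_query, dims)
# per-camera visibility of each query
hits = torch.tensor([[1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 0]], dtype=torch.bool)

indexes = [cam_hits.nonzero().squeeze(-1) for cam_hits in hits]
max_len = max(len(idx) for idx in indexes)

rebatch = query.new_zeros(num_cams, max_len, dims)   # dense per-camera buffer
for i, idx in enumerate(indexes):
    rebatch[i, :len(idx)] = query[0, idx]            # gather visible queries

slots = torch.zeros_like(query[0])
for i, idx in enumerate(indexes):
    slots[idx] += rebatch[i, :len(idx)]              # scatter back
count = hits.sum(0).clamp(min=1)                     # cameras seeing each query
slots = slots / count[:, None]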
@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
            each query in each head. Default: 8.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=8,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.batch_first = batch_first
        self.output_proj = None
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)

        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1, 2).repeat(1, self.num_levels,
                                            self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
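The `grid_init` scheme above spreads the initial sampling offsets of the heads evenly around the unit circle (scaled so each head's farthest coordinate lies on the unit square) and grows them linearly per point, so deformable attention starts from a symmetric star pattern rather than from zero. The same construction in isolation:

import math
import torch

num_heads, num_levels, num_points = 8, 4, 8
thetas = torch.arange(num_heads, dtype=torch.float32) * (2.0 * math.pi / num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# normalize so the largest |coordinate| per head is 1, then tile over levels/points
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(
    num_heads, 1, 1, 2).repeat(1, num_levels, num_points, 1)
for i in range(num_points):
    grid_init[:, :, i, :] *= i + 1   # point i starts (i + 1) cells out

assert grid_init.shape == (num_heads, num_levels, num_points, 2)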
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
( bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
`(bs, num_key, embed_dims)`.
value (Tensor): The value tensor with shape
`(bs, num_key, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos

        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels,
            self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)

        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points)

        if reference_points.shape[-1] == 2:
            """
            For each BEV query, it owns `num_Z_anchors` in 3D space that have different heights.
            After projecting, each BEV query has `num_Z_anchors` reference points in each 2D image.
            For each reference point, we sample `num_points` sampling points.
            For `num_Z_anchors` reference points, it has overall `num_points * num_Z_anchors` sampling points.
            """
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

            bs, num_query, num_Z_anchors, xy = reference_points.shape
            reference_points = reference_points[:, :, None, None, None, :, :]
            sampling_offsets = sampling_offsets / \
                offset_normalizer[None, None, None, :, None, :]
            bs, num_query, num_heads, num_levels, num_all_points, xy = \
                sampling_offsets.shape
            sampling_offsets = sampling_offsets.view(
                bs, num_query, num_heads, num_levels,
                num_all_points // num_Z_anchors, num_Z_anchors, xy)
            sampling_locations = reference_points + sampling_offsets
            bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = \
                sampling_locations.shape
            assert num_all_points == num_points * num_Z_anchors

            sampling_locations = sampling_locations.view(
                bs, num_query, num_heads, num_levels, num_all_points, xy)

        elif reference_points.shape[-1] == 4:
            assert False
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')

        # sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2
        # attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points

        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return output
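The `num_Z_anchors` bookkeeping above only works because the head's point budget is a multiple of the number of pillar anchors: the flat per-head point axis is factored into (points per anchor, anchors) so that every projected anchor gets its own group of learned offsets, then flattened back. A toy check of that factoring (shapes only, no attention computation):

import torch

bs, num_query, num_heads, num_levels = 1, 10, 2, 1
num_Z_anchors, num_points_total = 4, 8   # 8 offsets = 2 per anchor x 4 anchors
reference_points = torch.rand(bs, num_query, num_Z_anchors, 2)
offsets = torch.rand(bs, num_query, num_heads, num_levels, num_points_total, 2)

ref = reference_points[:, :, None, None, None, :, :]
off = offsets.view(bs, num_query, num_heads, num_levels,
                   num_points_total // num_Z_anchors, num_Z_anchors, 2)
locations = (ref + off).view(
    bs, num_query, num_heads, num_levels, num_points_total, 2)
assert locations.shape == offsets.shape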
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import ATTENTION
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@ATTENTION.register_module()
class TemporalSelfAttention(BaseModule):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to True.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
        num_bev_queue (int): In this version, we only use one history BEV and one current BEV;
            the length of the BEV queue is 2.
"""
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 num_bev_queue=2,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.num_bev_queue = num_bev_queue
        self.sampling_offsets = nn.Linear(
            embed_dims * self.num_bev_queue,
            num_bev_queue * num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims * self.num_bev_queue,
            num_bev_queue * num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1, 2).repeat(
            1, self.num_levels * self.num_bev_queue, self.num_points, 1)

        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
        if value is None:
            assert self.batch_first
            bs, len_bev, c = query.shape
            value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
            # value = torch.cat([query, query], 0)

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, embed_dims = query.shape
        _, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        assert self.num_bev_queue == 2

        query = torch.cat([value[:bs], query], -1)
        value = self.value_proj(value)

        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)

        value = value.reshape(bs * self.num_bev_queue,
                              num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query)
        sampling_offsets = sampling_offsets.view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels, self.num_points)

        attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads,
                     self.num_levels, self.num_points).contiguous()
        sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads,
                     self.num_levels, self.num_points, 2)

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]

        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        # output shape (bs*num_bev_queue, num_query, embed_dims)
        # (bs*num_bev_queue, num_query, embed_dims) -> (num_query, embed_dims, bs*num_bev_queue)
        output = output.permute(1, 2, 0)

        # fuse history value and current value
        # (num_query, embed_dims, bs*num_bev_queue) -> (num_query, embed_dims, bs, num_bev_queue)
        output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
        output = output.mean(-1)

        # (num_query, embed_dims, bs) -> (bs, num_query, embed_dims)
        output = output.permute(2, 0, 1)

        output = self.output_proj(output)

        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
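The tail of this forward pass fuses the two BEV-queue streams with a plain mean: the doubled batch is unfolded back into (bs, num_bev_queue) and averaged, weighting the history BEV and the current BEV equally. The same reshaping in isolation (a toy tensor standing in for the attention output):

import torch

bs, num_bev_queue, num_query, embed_dims = 2, 2, 6, 4
out = torch.rand(bs * num_bev_queue, num_query, embed_dims)  # stand-in output

fused = (out.permute(1, 2, 0)                      # (num_query, dims, bs*queue)
            .view(num_query, embed_dims, bs, num_bev_queue)
            .mean(-1)                              # average history/current
            .permute(2, 0, 1))                     # (bs, num_query, dims)
assert fused.shape == (bs, num_query, embed_dims)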
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer.py
0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 rotate_center=[100, 100],
                 **kwargs):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center
    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 3)
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) \
                    or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution='uniform', bias=0.)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
    def get_bev_features(self,
                         mlvl_feats,
                         bev_queries,
                         bev_h,
                         bev_w,
                         grid_length=[0.512, 0.512],
                         bev_pos=None,
                         prev_bev=None,
                         **kwargs):
        """
        obtain bev features.
        """
        bs = mlvl_feats[0].size(0)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # obtain rotation angle and shift with ego motion
        delta_x = np.array([each['can_bus'][0]
                            for each in kwargs['img_metas']])
        delta_y = np.array([each['can_bus'][1]
                            for each in kwargs['img_metas']])
        ego_angle = np.array(
            [each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas']])
        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
        translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = translation_length * \
            np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
        shift_x = translation_length * \
            np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
        shift_y = shift_y * self.use_shift
        shift_x = shift_x * self.use_shift
        shift = bev_queries.new_tensor(
            [shift_x, shift_y]).permute(1, 0)  # xy, bs -> bs, xy

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)
            if self.rotate_prev_bev:
                for i in range(bs):
                    # num_prev_bev = prev_bev.size(1)
                    rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(
                        bev_h, bev_w, -1).permute(2, 0, 1)
                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
                                          center=self.rotate_center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # add can bus signals
        can_bus = bev_queries.new_tensor(
            [each['can_bus'] for each in kwargs['img_metas']])  # [:, :]
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        bev_queries = bev_queries + can_bus * self.use_can_bus

        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None,
                                            None, lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            **kwargs)

        return bev_embed
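The ego-motion compensation above reduces to decomposing the ego translation into the BEV frame and normalizing by grid size: `translation_length` and the angle between heading and motion give per-axis offsets in grid cells, then dividing by bev_h/bev_w lands them in normalized BEV units. A standalone numeric sketch with made-up CAN-bus deltas (the values are illustrative only):

import numpy as np

# Made-up ego motion: about 1.5 m forward-left, heading 30 degrees.
delta_x, delta_y = np.array([1.2]), np.array([0.9])
ego_angle = np.array([30.0])                 # degrees
grid_length_y = grid_length_x = 0.512        # meters per BEV cell
bev_h = bev_w = 200

translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
print(shift_x, shift_y)   # normalized BEV shift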
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed',
                         'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=[0.512, 0.512],
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                **kwargs):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, num_cams, embed_dims, h, w].
bev_queries (Tensor): (bev_h*bev_w, c)
bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
object_query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when `with_box_refine` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- bev_embed: BEV features
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps., has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
        bev_embed = self.get_bev_features(
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        bev_embed = bev_embed.permute(1, 0, 2)

        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            **kwargs)

        inter_references_out = inter_references

        return bev_embed, inter_states, init_reference_out, inter_references_out
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer_occ.py
0 → 100644
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import PLUGIN_LAYERS, Conv2d, Conv3d, ConvModule, caffe2_xavier_init
@TRANSFORMER.register_module()
class TransformerOcc(BaseModule):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 use_3d=False,
                 use_conv=False,
                 rotate_center=[100, 100],
                 num_classes=18,
                 out_dim=32,
                 pillar_h=16,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='BN', ),
                 norm_cfg_3d=dict(type='BN3d', ),
                 **kwargs):
        super(TransformerOcc, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds
        self.use_3d = use_3d
        self.use_conv = use_conv
        self.pillar_h = pillar_h
        self.out_dim = out_dim
        if not use_3d:
            if use_conv:
                use_bias = norm_cfg is None
                self.decoder = nn.Sequential(
                    ConvModule(
                        self.embed_dims,
                        self.embed_dims,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=use_bias,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg),
                    ConvModule(
                        self.embed_dims,
                        self.embed_dims * 2,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=use_bias,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg),
                )
            else:
                self.decoder = nn.Sequential(
                    nn.Linear(self.embed_dims, self.embed_dims * 2),
                    nn.Softplus(),
                    nn.Linear(self.embed_dims * 2, self.embed_dims * 2),
                )
        else:
            use_bias_3d = norm_cfg_3d is None
            self.middle_dims = self.embed_dims // pillar_h
            self.decoder = nn.Sequential(
                ConvModule(
                    self.middle_dims,
                    self.out_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias_3d,
                    conv_cfg=dict(type='Conv3d'),
                    norm_cfg=norm_cfg_3d,
                    act_cfg=act_cfg),
                ConvModule(
                    self.out_dim,
                    self.out_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias_3d,
                    conv_cfg=dict(type='Conv3d'),
                    norm_cfg=norm_cfg_3d,
                    act_cfg=act_cfg),
            )
        self.predicter = nn.Sequential(
            nn.Linear(self.out_dim, self.out_dim * 2),
            nn.Softplus(),
            nn.Linear(self.out_dim * 2, num_classes),
        )
        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center
    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        # self.reference_points = nn.Linear(self.embed_dims, 3)
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) \
                    or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        # xavier_init(self.reference_points, distribution='uniform', bias=0.)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
    def get_bev_features(self,
                         mlvl_feats,
                         bev_queries,
                         bev_h,
                         bev_w,
                         grid_length=[0.512, 0.512],
                         bev_pos=None,
                         prev_bev=None,
                         **kwargs):
        """
        obtain bev features.
        """
bs
=
mlvl_feats
[
0
].
size
(
0
)
bev_queries
=
bev_queries
.
unsqueeze
(
1
).
repeat
(
1
,
bs
,
1
)
bev_pos
=
bev_pos
.
flatten
(
2
).
permute
(
2
,
0
,
1
)
        # obtain rotation angle and shift with ego motion
        delta_x = np.array([each['can_bus'][0]
                            for each in kwargs['img_metas']])
        delta_y = np.array([each['can_bus'][1]
                            for each in kwargs['img_metas']])
        ego_angle = np.array(
            [each['can_bus'][-2] / np.pi * 180
             for each in kwargs['img_metas']])
        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
        translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = translation_length * \
            np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
        shift_x = translation_length * \
            np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
        shift_y = shift_y * self.use_shift
        shift_x = shift_x * self.use_shift
        shift = bev_queries.new_tensor(
            [shift_x, shift_y]).permute(1, 0)  # xy, bs -> bs, xy
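        # Worked example (illustrative numbers, not from this diff): with
        # delta_x=1.0 m, delta_y=0.0 m, ego_angle=90 deg, grid_length=0.512 m
        # and bev_h=bev_w=200, we get translation_length=1.0,
        # translation_angle=0 deg, bev_angle=90 deg, so
        # shift_y = 1.0*cos(90deg)/0.512/200 = 0.0 and
        # shift_x = 1.0*sin(90deg)/0.512/200 ~= 0.0098, i.e. one metre of
        # forward motion becomes ~1.95 grid cells, normalized by bev_w.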
        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)
            elif len(prev_bev.shape) == 4:
                prev_bev = prev_bev.view(bs, -1,
                                         bev_h * bev_w).permute(2, 0, 1)
            if self.rotate_prev_bev:
                for i in range(bs):
                    # num_prev_bev = prev_bev.size(1)
                    rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(
                        bev_h, bev_w, -1).permute(2, 0, 1)
                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
                                          center=self.rotate_center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]
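        # Note: `rotate` is assumed here to be a torchvision-style image
        # rotation (imported elsewhere in this plugin). Each sample's previous
        # BEV map is treated as a (C, bev_h, bev_w) image and rotated in-plane
        # by the relative yaw in can_bus[-1] about `rotate_center`, so the old
        # grid aligns with the current ego frame before temporal attention.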
        # add can bus signals
        can_bus = bev_queries.new_tensor(
            [each['can_bus'] for each in kwargs['img_metas']])  # [:, :]
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        bev_queries = bev_queries + can_bus * self.use_can_bus
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None,
                                            lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)
        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
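        # Worked example (illustrative shapes): for two feature levels with
        # spatial_shapes [[116, 200], [58, 100]], prod(1) gives [23200, 5800]
        # and level_start_index becomes [0, 23200] -- the offset of each
        # level's first token inside the concatenated key/value sequence.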
        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)
        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            **kwargs
        )
        return bev_embed
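For orientation, here is a hedged sketch of how a caller might drive `get_bev_features`; the sizes (6 cameras, one feature level, a 50x50 BEV grid, embed_dims=256) are illustrative assumptions, not values from this diff, and the call itself is left commented since the module cannot be built standalone here.

import torch

bs, num_cams, C, H, W = 1, 6, 256, 25, 15      # assumed sizes
bev_h = bev_w = 50

mlvl_feats = [torch.randn(bs, num_cams, C, H, W)]   # one FPN level
bev_queries = torch.randn(bev_h * bev_w, C)         # learned grid queries
bev_pos = torch.randn(bs, C, bev_h, bev_w)          # positional encoding

# bev_embed = transformer.get_bev_features(
#     mlvl_feats, bev_queries, bev_h, bev_w,
#     grid_length=[0.512, 0.512], bev_pos=bev_pos, prev_bev=None,
#     img_metas=[dict(can_bus=[0.0] * 18)])
# Expected result: BEV features of shape (bs, bev_h * bev_w, C).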
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed',
                         'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=[0.512, 0.512],
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from different levels.
                Each element has shape [bs, num_cams, embed_dims, h, w].
            bev_queries (Tensor): (bev_h*bev_w, c).
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w).
            object_query_embed (Tensor): The query embedding for the decoder,
                with shape [num_query, c].
            reg_branches (obj:`nn.ModuleList`): Regression heads for feature
                maps from each decoder layer. Only passed when
                `with_box_refine` is True. Default: None.
        Returns:
            tuple[Tensor]: Results of the decoder, containing:
                - bev_embed: BEV features.
                - inter_states: Outputs from the decoder. If
                  return_intermediate_dec is True, the output has shape
                  (num_dec_layers, bs, num_query, embed_dims); otherwise
                  (1, bs, num_query, embed_dims).
                - init_reference_out: The initial reference points, with
                  shape (bs, num_queries, 4).
                - inter_references_out: The intermediate reference points in
                  the decoder, with shape
                  (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification scores of proposals
                  generated from the encoder's feature maps, with shape
                  (batch, h*w, num_classes). Only returned when
                  `as_two_stage` is True; otherwise None.
                - enc_outputs_coord_unact: The regression results generated
                  from the encoder's feature maps, with shape (batch, h*w, 4).
                  Only returned when `as_two_stage` is True; otherwise None.
        """
        bev_embed = self.get_bev_features(
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)  # bev_embed shape: (bs, bev_h*bev_w, embed_dims)
        bs = mlvl_feats[0].size(0)
        bev_embed = bev_embed.permute(0, 2, 1).view(bs, -1, bev_h, bev_w)
        if self.use_3d:
            outputs = self.decoder(
                bev_embed.view(bs, -1, self.pillar_h, bev_h, bev_w))
            outputs = outputs.permute(0, 4, 3, 2, 1)
        elif self.use_conv:
            outputs = self.decoder(bev_embed)
            outputs = outputs.view(bs, -1, self.pillar_h,
                                   bev_h, bev_w).permute(0, 3, 4, 2, 1)
        else:
            outputs = self.decoder(bev_embed.permute(0, 2, 3, 1))
            outputs = outputs.view(bs, bev_h, bev_w,
                                   self.pillar_h, self.out_dim)
        outputs = self.predicter(outputs)
        # print('outputs', type(outputs))
        return bev_embed, outputs
\ No newline at end of file
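After the encoder, `forward` reshapes the flat BEV embedding back onto the grid and routes it through one of the three decoder variants; in every branch the result is arranged so the last axis holds per-voxel channels, which `predicter` maps to `num_classes` occupancy logits. A hedged shape walk-through of the `use_3d` branch, with a plain Conv3d standing in for the real decoder and all sizes assumed:

import torch

# Illustrative sizes only; none of these are confirmed by the diff.
bs, embed_dims, pillar_h, bev_h, bev_w, out_dim = 1, 256, 16, 200, 200, 32

bev_embed = torch.randn(bs, embed_dims, bev_h, bev_w)
# use_3d branch: unfold channels into pillar_h vertical slices ...
vox = bev_embed.view(bs, embed_dims // pillar_h, pillar_h, bev_h, bev_w)
# ... run the 3D decoder (a single Conv3d stand-in here), then reorder to
# (bs, bev_w, bev_h, pillar_h, out_dim) for the per-voxel classifier.
decoder = torch.nn.Conv3d(embed_dims // pillar_h, out_dim, 3, padding=1)
outputs = decoder(vox).permute(0, 4, 3, 2, 1)
print(outputs.shape)  # torch.Size([1, 200, 200, 16, 32])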
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/__init__.py
0 → 100644
View file @
b64d9ca3
from .epoch_based_runner import EpochBasedRunner_video
\ No newline at end of file
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py
0 → 100644
View file @
b64d9ca3
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import os.path as osp

import torch
import mmcv
from mmcv.runner.base_runner import BaseRunner
from mmcv.runner.epoch_based_runner import EpochBasedRunner
from mmcv.runner.builder import RUNNERS
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.utils import get_host_info
from pprint import pprint
from mmcv.parallel.data_container import DataContainer


@RUNNERS.register_module()
class EpochBasedRunner_video(EpochBasedRunner):
    '''
    # Basic logic:
    input_sequence = [a, b, c]  # given a sequence of samples
    prev_bev = None
    for each in input_sequence[:-1]:
        prev_bev = eval_model(each, prev_bev)  # inference only
    model(input_sequence[-1], prev_bev)  # train on the last sample
    '''
    def __init__(self,
                 model,
                 eval_model=None,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'],
                 max_iters=None,
                 max_epochs=None):
        super().__init__(model, batch_processor, optimizer, work_dir, logger,
                         meta, max_iters, max_epochs)
        keys = list(keys)  # avoid mutating the shared default list
        keys.append('img_metas')
        self.keys = keys
        self.eval_model = eval_model
        self.eval_model.eval()
    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            assert False
            # outputs = self.batch_processor(
            #     self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            num_samples = data_batch['img'].data[0].size(1)
            data_list = []
            prev_bev = None
            for i in range(num_samples):
                data = {}
                for key in self.keys:
                    if key not in ['img_metas', 'img', 'points']:
                        data[key] = data_batch[key]
                    else:
                        if key == 'img':
                            data['img'] = DataContainer(
                                data=[data_batch['img'].data[0][:, i]],
                                cpu_only=data_batch['img'].cpu_only,
                                stack=True)
                        elif key == 'img_metas':
                            data['img_metas'] = DataContainer(
                                data=[[each[i] for each in
                                       data_batch['img_metas'].data[0]]],
                                cpu_only=data_batch['img_metas'].cpu_only)
                        else:
                            assert False
                data_list.append(data)
            with torch.no_grad():
                for i in range(num_samples - 1):
                    if data_list[i]['img_metas'].data[0][0]['prev_bev_exists']:
                        data_list[i]['prev_bev'] = DataContainer(
                            data=[prev_bev], cpu_only=False)
                    prev_bev = self.eval_model.val_step(
                        data_list[i], self.optimizer, **kwargs)
            if data_list[-1]['img_metas'].data[0][0]['prev_bev_exists']:
                data_list[-1]['prev_bev'] = DataContainer(
                    data=[prev_bev], cpu_only=False)
            outputs = self.model.train_step(data_list[-1], self.optimizer,
                                            **kwargs)
        else:
            assert False
            # outputs = self.model.val_step(data_batch, self.optimizer,
            #                               **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
\ No newline at end of file
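The runner's control flow is easier to see stripped of the DataContainer plumbing. The following self-contained sketch re-implements the docstring's basic logic with toy modules; `model`, `eval_model` and the frame tensors are stand-ins, not the repo's classes:

import torch
import torch.nn as nn

# Toy stand-ins: a frozen copy runs inference over the history, and the
# live model is trained only on the last frame of each clip.
model = nn.Linear(8, 8)
eval_model = nn.Linear(8, 8)
eval_model.load_state_dict(model.state_dict())
eval_model.eval()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

clip = [torch.randn(1, 8) for _ in range(3)]   # a short temporal sequence

prev_bev = None
with torch.no_grad():                          # history frames: inference only
    for frame in clip[:-1]:
        inp = frame if prev_bev is None else frame + prev_bev
        prev_bev = eval_model(inp)

inp = clip[-1] if prev_bev is None else clip[-1] + prev_bev
loss = model(inp).pow(2).mean()                # train on the last frame only
optimizer.zero_grad()
loss.backward()
optimizer.step()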
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/core/bbox/assigners/__init__.py
0 → 100644
View file @
b64d9ca3
from .hungarian_assigner_3d import HungarianAssigner3D

__all__ = ['HungarianAssigner3D']
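HungarianAssigner3D follows the DETR-style recipe: build a cost matrix between predictions and ground-truth boxes, then solve the resulting bipartite matching. A hedged sketch of that core step using scipy; the L1 center cost here is a simplified placeholder for the assigner's actual classification and regression costs:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Simplified cost: L1 distance between predicted and ground-truth centers.
pred_centers = np.array([[1.0, 2.0], [5.0, 5.0], [0.2, 0.1]])
gt_centers = np.array([[0.0, 0.0], [5.1, 4.9]])

cost = np.abs(pred_centers[:, None, :] - gt_centers[None, :, :]).sum(-1)
pred_idx, gt_idx = linear_sum_assignment(cost)  # optimal one-to-one matching
print(list(zip(pred_idx, gt_idx)))              # [(1, 1), (2, 0)]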