OpenDAS / dcnv3

Commit df3c64a9 (parent bdd98bcb), authored Apr 17, 2023 by zhiqi-li:
support occupancy prediction
Showing 20 changed files with 3400 additions and 0 deletions.

All paths are under autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/:

+17    bevformer/backbones/ops_dcnv3/src/vision.cpp
+263   bevformer/backbones/ops_dcnv3/test.py
+1     bevformer/dense_heads/__init__.py
+207   bevformer/dense_heads/bevformer_occ_head.py
+1     bevformer/detectors/__init__.py
+298   bevformer/detectors/bevformer_occ.py
+1     bevformer/hooks/__init__.py
+14    bevformer/hooks/custom_hooks.py
+6     bevformer/modules/__init__.py
+262   bevformer/modules/custom_base_transformer_layer.py
+345   bevformer/modules/decoder.py
+408   bevformer/modules/encoder.py
+163   bevformer/modules/multi_scale_deformable_attn_function.py
+400   bevformer/modules/spatial_cross_attention.py
+272   bevformer/modules/temporal_self_attention.py
+289   bevformer/modules/transformer.py
+352   bevformer/modules/transformer_occ.py
+1     bevformer/runner/__init__.py
+97    bevformer/runner/epoch_based_runner.py
+3     core/bbox/assigners/__init__.py
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/src/vision.cpp (new file, mode 100644)

/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "dcnv3.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
    m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
}
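Note: the pybind module above only exports the raw dcnv3_forward/dcnv3_backward entry points. As a rough orientation, such extensions are commonly JIT-compiled with torch.utils.cpp_extension; the source layout in this sketch is an assumption, not part of the commit:

# Hypothetical build snippet (file layout assumed, not from this commit):
# JIT-compile the DCNv3 extension so its ops are callable from Python.
from torch.utils.cpp_extension import load

dcnv3_ext = load(
    name='DCNv3',
    sources=['src/vision.cpp',           # the binding file above
             'src/cpu/dcnv3_cpu.cpp',    # assumed CPU kernels
             'src/cuda/dcnv3_cuda.cu'],  # assumed CUDA kernels
    extra_cflags=['-O3'],
    verbose=True,
)
# dcnv3_ext.dcnv3_forward(...) and dcnv3_ext.dcnv3_backward(...) then
# dispatch into the compiled kernels.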
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/backbones/ops_dcnv3/test.py (new file, mode 100644)

# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck

from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
P = Kh * Kw
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input.double(), offset.double(), mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input.double(), offset.double(), mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale, im2col_step).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(
        input, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input, offset, mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale, im2col_step).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_double(channels=4,
                                             grad_input=True,
                                             grad_offset=True,
                                             grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0.double(), offset0.double(), mask0.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1.double(), offset1.double(), mask1.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale, im2col_step)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4,
                                            grad_input=True,
                                            grad_offset=True,
                                            grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0, offset0, mask0,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1, offset1, mask1,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2,
        dilation, dilation, M, D, offset_scale, im2col_step)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: '
          f'max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_time_cost(im2col_step=128):
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step}; '
          f'input {input.shape}; points {P}')
    repeat = 100
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(
            input, offset, mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0, im2col_step)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(
            input, offset, mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0, im2col_step)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    for i in range(3):
        im2col_step = 128 * (2 ** i)
        check_time_cost(im2col_step)
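Note: the H_out/W_out expressions in this test are the standard convolution output-size arithmetic. A quick standalone check with the module-level constants (not part of the commit):

# With Kh = 3, pad = 1, dilation = 1, stride = 1, an 8x8 input keeps its size:
H_in, pad, dilation, Kh, stride = 8, 1, 1, 3, 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
assert H_out == 8  # (8 + 2 - 3) // 1 + 1 = 8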
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py (new file, mode 100644)

from .bevformer_occ_head import BEVFormerOccHead
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_occ_head.py (new file, mode 100644)

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.utils import TORCH_VERSION, digit_version

from mmdet.core import multi_apply, reduce_mean
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import HEADS
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import BaseModule, force_fp32, auto_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time
import numpy as np
import mmcv
import cv2 as cv
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmdet.models.utils import build_transformer
from mmdet.models.builder import build_loss


@HEADS.register_module()
class BEVFormerOccHead(BaseModule):
    """Head of Detr3D.
    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool): Whether to generate the proposal from
            the outputs of the encoder.
        transformer (obj:`ConfigDict`): ConfigDict used for building
            the Encoder and Decoder.
        bev_h, bev_w (int): Spatial shape of BEV queries.
    """

    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 bbox_coder=None,
                 num_cls_fcs=2,
                 code_weights=None,
                 pc_range=[-40, -40, -1.0, 40, 40, 5.4],
                 bev_h=30,
                 bev_w=30,
                 loss_occ=None,
                 use_mask=False,
                 positional_encoding=None,
                 **kwargs):
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.fp16_enabled = False
        self.num_classes = kwargs['num_classes']
        self.use_mask = use_mask

        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            transformer['as_two_stage'] = self.as_two_stage
        self.pc_range = pc_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]
        self.num_cls_fcs = num_cls_fcs - 1
        super(BEVFormerOccHead, self).__init__()

        self.loss_occ = build_loss(loss_occ)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims
        if not self.as_two_stage:
            self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w,
                                              self.embed_dims)

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        # if self.loss_cls.use_sigmoid:
        #     bias_init = bias_init_with_prob(0.01)
        #     for m in self.cls_branches:
        #         nn.init.constant_(m[-1].bias, bias_init)

    @auto_fp16(apply_to=('mlvl_feats'))
    def forward(self, mlvl_feats, img_metas, prev_bev=None,
                only_bev=False, test=False):
        """Forward function.
        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each a 5D tensor with shape (B, N, C, H, W).
            prev_bev: Previous BEV features.
            only_bev: Only compute BEV features with the encoder.
        Returns:
            dict: 'bev_embed' (the BEV features from the encoder) and
                'occ' (the occupancy predictions decoded from them).
        """

        bs, num_cam, _, _, _ = mlvl_feats[0].shape
        dtype = mlvl_feats[0].dtype
        object_query_embeds = None
        bev_queries = self.bev_embedding.weight.to(dtype)

        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)

        if only_bev:
            # only use the encoder to obtain BEV features,
            # TODO: refine the workaround
            return self.transformer.get_bev_features(
                mlvl_feats,
                bev_queries,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                img_metas=img_metas,
                prev_bev=prev_bev,
            )
        else:
            outputs = self.transformer(
                mlvl_feats,
                bev_queries,
                object_query_embeds,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                reg_branches=None,  # noqa:E501
                cls_branches=None,
                img_metas=img_metas,
                prev_bev=prev_bev)

        bev_embed, occ_outs = outputs

        outs = {
            'bev_embed': bev_embed,
            'occ': occ_outs,
        }

        return outs

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self,
             # gt_bboxes_list,
             # gt_labels_list,
             voxel_semantics,
             mask_camera,
             preds_dicts,
             gt_bboxes_ignore=None,
             img_metas=None):
        loss_dict = dict()
        occ = preds_dicts['occ']
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        losses = self.loss_single(voxel_semantics, mask_camera, occ)
        loss_dict['loss_occ'] = losses
        return loss_dict

    def loss_single(self, voxel_semantics, mask_camera, preds):
        voxel_semantics = voxel_semantics.long()
        if self.use_mask:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = preds.reshape(-1, self.num_classes)
            mask_camera = mask_camera.reshape(-1)
            num_total_samples = mask_camera.sum()
            loss_occ = self.loss_occ(preds, voxel_semantics, mask_camera,
                                     avg_factor=num_total_samples)
        else:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = preds.reshape(-1, self.num_classes)
            loss_occ = self.loss_occ(preds, voxel_semantics)
        return loss_occ

    @force_fp32(apply_to=('preds'))
    def get_occ(self, preds_dicts, img_metas, rescale=False):
        """Generate occupancy predictions from head outputs.
        Args:
            preds_dicts: Occupancy results.
            img_metas (list[dict]): Point cloud and image meta info.
        Returns:
            Tensor: Per-voxel semantic class indices.
        """
        # return self.transformer.get_occ(
        #     preds_dicts, img_metas, rescale=rescale)
        # print(img_metas[0].keys())
        occ_out = preds_dicts['occ']
        occ_score = occ_out.softmax(-1)
        occ_score = occ_score.argmax(-1)

        return occ_score
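Note: the head is registered with the mmdet HEADS registry, so it is normally built from a config dict. The sketch below shows an assumed shape of such a config; every value (class count, BEV size, loss type) is a placeholder and not taken from this commit:

# Hypothetical config sketch (values are assumptions, not from this commit).
occ_head_cfg = dict(
    type='BEVFormerOccHead',
    num_classes=18,            # matches the 0..17 label range asserted in loss()
    bev_h=200, bev_w=200,
    pc_range=[-40, -40, -1.0, 40, 40, 5.4],
    use_mask=True,             # weight the loss by the camera-visibility mask
    loss_occ=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    positional_encoding=dict(type='LearnedPositionalEncoding', num_feats=128,
                             row_num_embed=200, col_num_embed=200),
    transformer=dict(type='TransformerOcc'),  # plus encoder cfg; see modules/transformer_occ.py
)
# Built with: from mmdet.models.builder import build_head
#             head = build_head(occ_head_cfg)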
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/__init__.py (new file, mode 100644)

from .bevformer_occ import BEVFormerOcc
\ No newline at end of file
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/detectors/bevformer_occ.py (new file, mode 100644)

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
import time
import copy
import numpy as np
import mmdet3d
from projects.mmdet3d_plugin.models.utils.bricks import run_time


@DETECTORS.register_module()
class BEVFormerOcc(MVXTwoStageDetector):
    """BEVFormer.
    Args:
        video_test_mode (bool): Decide whether to use temporal information
            during inference.
    """

    def __init__(self,
                 use_grid_mask=False,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False):

        super(BEVFormerOcc, self).__init__(
            pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
            pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
            pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg,
            pretrained)
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # temporal
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

    def extract_img_feat(self, img, img_metas, len_queue=None):
        """Extract features of images."""
        B = img.size(0)
        if img is not None:
            # input_shape = img.shape[-2:]
            # # update real input shape of each single img
            # for img_meta in img_metas:
            #     img_meta.update(input_shape=input_shape)

            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                img = img.reshape(B * N, C, H, W)
            if self.use_grid_mask:
                img = self.grid_mask(img)

            img_feats = self.img_backbone(img)
            if isinstance(img_feats, dict):
                img_feats = list(img_feats.values())
        else:
            return None
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            if len_queue is not None:
                img_feats_reshaped.append(img_feat.view(
                    int(B / len_queue), len_queue, int(BN / B), C, H, W))
            else:
                img_feats_reshaped.append(
                    img_feat.view(B, int(BN / B), C, H, W))
        return img_feats_reshaped

    @auto_fp16(apply_to=('img'))
    def extract_feat(self, img, img_metas=None, len_queue=None):
        """Extract features from images and points."""
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
        return img_feats

    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          voxel_semantics,
                          mask_camera,
                          img_metas,
                          gt_bboxes_ignore=None,
                          prev_bev=None):
        """Forward function.
        Args:
            pts_feats (list[torch.Tensor]): Features of point cloud branch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            prev_bev (torch.Tensor, optional): BEV features of previous frame.
        Returns:
            dict: Losses of each branch.
        """
        outs = self.pts_bbox_head(pts_feats, img_metas, prev_bev)
        loss_inputs = [voxel_semantics, mask_camera, outs]
        losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
        return losses

    def forward_dummy(self, img):
        dummy_metas = None
        return self.forward_test(img=img, img_metas=[[dummy_metas]])

    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def obtain_history_bev(self, imgs_queue, img_metas_list):
        """Obtain history BEV features iteratively. To save GPU memory,
        gradients are not calculated.
        """
        self.eval()

        with torch.no_grad():
            prev_bev = None
            bs, len_queue, num_cams, C, H, W = imgs_queue.shape
            imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
            img_feats_list = self.extract_feat(
                img=imgs_queue, len_queue=len_queue)
            for i in range(len_queue):
                img_metas = [each[i] for each in img_metas_list]
                if not img_metas[0]['prev_bev_exists']:
                    prev_bev = None
                # img_feats = self.extract_feat(img=img, img_metas=img_metas)
                img_feats = [each_scale[:, i] for each_scale in img_feats_list]
                prev_bev = self.pts_bbox_head(
                    img_feats, img_metas, prev_bev, only_bev=True)
            self.train()
            return prev_bev

    @auto_fp16(apply_to=('img', 'points'))
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      voxel_semantics=None,
                      mask_lidar=None,
                      mask_camera=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      img_depth=None,
                      img_mask=None,
                      ):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """

        len_queue = img.size(1)
        prev_img = img[:, :-1, ...]
        img = img[:, -1, ...]

        prev_img_metas = copy.deepcopy(img_metas)
        prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)

        img_metas = [each[len_queue - 1] for each in img_metas]
        if not img_metas[0]['prev_bev_exists']:
            prev_bev = None
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, voxel_semantics,
                                            mask_camera, img_metas,
                                            gt_bboxes_ignore, prev_bev)

        losses.update(losses_pts)
        return losses

    def forward_test(self, img_metas, img=None, voxel_semantics=None,
                     mask_lidar=None, mask_camera=None, **kwargs):
        for var, name in [(img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))
        img = [img] if img is None else img

        if img_metas[0][0]['scene_token'] != self.prev_frame_info['scene_token']:
            # the first sample of each scene is truncated
            self.prev_frame_info['prev_bev'] = None
        # update idx
        self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']

        # do not use temporal information
        if not self.video_test_mode:
            self.prev_frame_info['prev_bev'] = None

        # Get the delta of ego position and angle between two timestamps.
        tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
        tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
        if self.prev_frame_info['prev_bev'] is not None:
            img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
            img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
        else:
            img_metas[0][0]['can_bus'][-1] = 0
            img_metas[0][0]['can_bus'][:3] = 0

        new_prev_bev, occ_results = self.simple_test(
            img_metas[0], img[0], prev_bev=self.prev_frame_info['prev_bev'],
            **kwargs)
        # During inference, we save the BEV features and ego motion of each
        # timestamp.
        self.prev_frame_info['prev_pos'] = tmp_pos
        self.prev_frame_info['prev_angle'] = tmp_angle
        self.prev_frame_info['prev_bev'] = new_prev_bev
        return occ_results

    def simple_test_pts(self, x, img_metas, prev_bev=None, rescale=False):
        """Test function."""
        outs = self.pts_bbox_head(x, img_metas, prev_bev=prev_bev, test=True)

        occ = self.pts_bbox_head.get_occ(outs, img_metas, rescale=rescale)

        return outs['bev_embed'], occ

    def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
        """Test function without augmentation."""
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        # bbox_list = [dict() for i in range(len(img_metas))]
        new_prev_bev, occ = self.simple_test_pts(
            img_feats, img_metas, prev_bev, rescale=rescale)
        # for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
        #     result_dict['pts_bbox'] = pts_bbox
        return new_prev_bev, occ
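Note: the easiest way to read the temporal bookkeeping above is from the caller's side: forward_test keeps prev_frame_info across calls, drops the cached BEV at scene boundaries, and rewrites the can_bus entries into ego-motion deltas between consecutive frames. A rough inference-loop sketch; model and data_loader are placeholders, not from this commit:

# Hypothetical inference loop (names are placeholders, not from this commit).
model.eval()
with torch.no_grad():
    for data in data_loader:                    # one frame per iteration
        occ = model(return_loss=False, **data)  # dispatches to forward_test
        # occ holds per-voxel class indices from BEVFormerOccHead.get_occ;
        # the BEV features of this frame were cached for the next call.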
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/__init__.py (new file, mode 100644)

from .custom_hooks import TransferWeight
\ No newline at end of file
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py (new file, mode 100644)

from mmcv.runner.hooks.hook import HOOKS, Hook
from projects.mmdet3d_plugin.models.utils import run_time


@HOOKS.register_module()
class TransferWeight(Hook):

    def __init__(self, every_n_inters=1):
        self.every_n_inters = every_n_inters

    def after_train_iter(self, runner):
        if self.every_n_inner_iters(runner, self.every_n_inters):
            runner.eval_model.load_state_dict(runner.model.state_dict())
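Note: hooks registered this way are enabled from the training config; the runner is assumed to expose an eval_model attribute for this to work. A minimal sketch of the expected usage, not part of the commit:

# Hypothetical config fragment: after every training iteration, copy the
# training weights into the runner's separate evaluation model.
custom_hooks = [dict(type='TransferWeight', every_n_inters=1)]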
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/__init__.py (new file, mode 100644)

from .transformer import PerceptionTransformer
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import DetectionTransformerDecoder
from .transformer_occ import TransformerOcc
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/custom_base_transformer_layer.py (new file, mode 100644)

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import copy
import warnings

import torch
import torch.nn as nn

from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                      POSITIONAL_ENCODING, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
    from mmcv.ops.multi_scale_deform_attn import \
        MultiScaleDeformableAttention  # noqa F401
    warnings.warn(
        ImportWarning(
            '``MultiScaleDeformableAttention`` has been moved to '
            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
            '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
            'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
        ))
except ImportError:
    warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
                  '``mmcv.ops.multi_scale_deform_attn``, '
                  'You should install ``mmcv-full`` if you need this module. ')

from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention


@TRANSFORMER_LAYER.register_module()
class MyCustomBaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformers.

    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN or LN` and
    using different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with the corresponding ffns in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when you specify `norm` as the first element.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contain all four operation types ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + \
            operation_order.count('cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attentions ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                # print('ffn_cfgs ', ffn_cfgs[ffn_index]['embed_dims'], self.embed_dims)
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index]))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensors used in
                the calculation of corresponding attention. The length
                should equal the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layers.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.
        Returns:
            Tensor: forwarded results with shape
                [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
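Note: a layer like this is driven entirely by operation_order. A rough config sketch; the attention types and dimensions below are assumed placeholders, not taken from this commit:

# Hypothetical layer config: one attn_cfg per 'self_attn'/'cross_attn' entry,
# in the order they appear in operation_order.
layer_cfg = dict(
    type='MyCustomBaseTransformerLayer',
    attn_cfgs=[
        dict(type='TemporalSelfAttention', embed_dims=256, num_levels=1),
        dict(type='SpatialCrossAttention', embed_dims=256),
    ],
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=512),
    operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'),
)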
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/decoder.py (new file, mode 100644)

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------

from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import mmcv
import cv2 as cv
import copy
import warnings
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32, MultiScaleDeformableAttnFunction_fp16

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.
    Args:
        x (Tensor): The tensor to invert.
        eps (float): EPS to avoid numerical overflow. Defaults to 1e-5.
    Returns:
        Tensor: x passed through the inverse of sigmoid, with the same
            shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetectionTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in the DETR3D transformer.
    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of the last normalization layer.
            Default: `LN`.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(DetectionTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.fp16_enabled = False

    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `Detr3DTransformerDecoder`.
        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points of offsets,
                with shape (bs, num_query, 4) when as_two_stage,
                otherwise (bs, num_query, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only passed when with_box_refine is
                True, otherwise `None`.
        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise with shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):

            reference_points_input = reference_points[..., :2].unsqueeze(
                2)  # BS NUM_QUERY NUM_LEVEL 2
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                key_padding_mask=key_padding_mask,
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)

                assert reference_points.shape[-1] == 3

                new_reference_points = torch.zeros_like(reference_points)
                new_reference_points[..., :2] = tmp[
                    ..., :2] + inverse_sigmoid(reference_points[..., :2])
                new_reference_points[..., 2:3] = tmp[
                    ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

                new_reference_points = new_reference_points.sigmoid()

                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points


@ATTENTION.register_module()
class CustomMSDeformableAttention(BaseModule):
    """An attention module used in Deformable-DETR.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.
        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area.
                Or (N, Length_{query}, num_levels, 4), where the two
                additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor with shape ``(num_levels, )`` which can be
                represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape
                [num_query, bs, embed_dims].
        """

        if value is None:
            value = query

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but got {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs
            # many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = \
                    MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = \
                    MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
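Note: inverse_sigmoid above is the logit function with clamping for numerical safety, so sigmoid(inverse_sigmoid(p)) recovers p away from the clamp boundaries. A standalone round-trip check, not part of the commit:

import torch

def inverse_sigmoid(x, eps=1e-5):
    # same clamped-logit formulation as in decoder.py
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

p = torch.tensor([0.1, 0.5, 0.9])
assert torch.allclose(torch.sigmoid(inverse_sigmoid(p)), p, atol=1e-4)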
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/encoder.py
0 → 100644
View file @
df3c64a9
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from
projects.mmdet3d_plugin.models.utils.bricks
import
run_time
from
projects.mmdet3d_plugin.models.utils.visual
import
save_tensor
from
.custom_base_transformer_layer
import
MyCustomBaseTransformerLayer
import
copy
import
warnings
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
from
mmcv.runner
import
force_fp32
,
auto_fp16
import
numpy
as
np
import
torch
import
cv2
as
cv
import
mmcv
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
BEVFormerEncoder
(
TransformerLayerSequence
):
"""
Attention with both self and cross
Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def
__init__
(
self
,
*
args
,
pc_range
=
None
,
num_points_in_pillar
=
4
,
return_intermediate
=
False
,
dataset_type
=
'nuscenes'
,
**
kwargs
):
super
(
BEVFormerEncoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
return_intermediate
=
return_intermediate
self
.
num_points_in_pillar
=
num_points_in_pillar
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
@
staticmethod
def
get_reference_points
(
H
,
W
,
Z
=
8
,
num_points_in_pillar
=
4
,
dim
=
'3d'
,
bs
=
1
,
device
=
'cuda'
,
dtype
=
torch
.
float
):
"""Get the reference points used in SCA and TSA.
Args:
H, W: spatial shape of bev.
Z: hight of pillar.
D: sample D points uniformly from each pillar.
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar,
                                dtype=dtype, device=device).view(
                -1, 1, 1).expand(num_points_in_pillar, H, W) / Z
            xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                                device=device).view(
                1, 1, W).expand(num_points_in_pillar, H, W) / W
            ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                                device=device).view(
                1, H, 1).expand(num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            # shape: (bs, num_points_in_pillar, h*w, 3)
            return ref_3d

        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device)
            )
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d
    # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        ego2lidar = img_metas[0]['ego2lidar']
        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        ego2lidar = reference_points.new_tensor(ego2lidar)
        # ego2lidar = ego2lidar.unsqueeze(dim=0).repeat(num_imgs,1,1).unsqueeze(0)
        reference_points = reference_points.clone()

        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)

        reference_points = reference_points.permute(1, 0, 2, 3)
        # shape: (num_points_in_pillar, bs, h*w, 4)
        D, B, num_query = reference_points.size()[:3]
        # D=num_points_in_pillar, num_query=h*w
        num_cam = lidar2img.size(1)

        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
        # shape: (num_points_in_pillar, bs, num_cam, h*w, 4)
        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
        ego2lidar = ego2lidar.view(1, 1, 1, 1, 4, 4).repeat(
            D, 1, num_cam, num_query, 1, 1)

        reference_points_cam = torch.matmul(
            torch.matmul(lidar2img.to(torch.float32),
                         ego2lidar.to(torch.float32)),
            reference_points.to(torch.float32)).squeeze(-1)

        eps = 1e-5

        bev_mask = (reference_points_cam[..., 2:3] > eps)
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3],
            torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))

        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        # shape: (num_cam, bs, h*w, num_points_in_pillar, 2)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        return reference_points_cam, bev_mask
    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
"""Forward function for `TransformerDecoder`.
Args:
bev_query (Tensor): Input BEV query with shape
`(num_query, bs, embed_dims)`.
key & value (Tensor): Input multi-cameta features with shape
(num_cam, num_value, bs, embed_dims)
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
valid_ratios (Tensor): The radios of valid
points on the feature map, has shape
(bs, num_levels, 2)
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
        output = bev_query
        intermediate = []

        ref_3d = self.get_reference_points(
            bev_h, bev_w, self.pc_range[5] - self.pc_range[2],
            self.num_points_in_pillar, dim='3d', bs=bev_query.size(1),
            device=bev_query.device, dtype=bev_query.dtype)
        ref_2d = self.get_reference_points(
            bev_h, bev_w, dim='2d', bs=bev_query.size(1),
            device=bev_query.device, dtype=bev_query.dtype)

        reference_points_cam, bev_mask = self.point_sampling(
            ref_3d, self.pc_range, kwargs['img_metas'])

        # bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
        shift_ref_2d = ref_2d  # .clone()
        shift_ref_2d += shift[:, None, None, :]

        # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
        bev_query = bev_query.permute(1, 0, 2)
        bev_pos = bev_pos.permute(1, 0, 2)
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        if prev_bev is not None:
            prev_bev = prev_bev.permute(1, 0, 2)
            prev_bev = torch.stack(
                [prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
            hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
                bs * 2, len_bev, num_bev_level, 2)
        else:
            hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
                bs * 2, len_bev, num_bev_level, 2)

        for lid, layer in enumerate(self.layers):
            output = layer(
                bev_query,
                key,
                value,
                *args,
                bev_pos=bev_pos,
                ref_2d=hybird_ref_2d,
                ref_3d=ref_3d,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cam=reference_points_cam,
                bev_mask=bev_mask,
                prev_bev=prev_bev,
                **kwargs)

            bev_query = output
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output
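To run temporal self-attention as a single kernel call, the encoder folds the (previous BEV, current BEV) pair into the batch dimension, pairing the ego-motion-shifted references with the history branch. A minimal sketch with toy sizes (not a real config) shows the resulting layouts:

import torch

bs, len_bev, num_bev_level, dims = 1, 4, 1, 8
bev_query = torch.rand(bs, len_bev, dims)
prev_bev = torch.rand(bs, len_bev, dims)
ref_2d = torch.rand(bs, len_bev, num_bev_level, 2)
shift_ref_2d = ref_2d + torch.tensor([0.01, 0.02])  # ego-motion shifted copy

# interleave (prev, current) along a new dim, then fold it into the batch
prev_and_cur = torch.stack([prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
    bs * 2, len_bev, num_bev_level, 2)
print(prev_and_cur.shape, hybird_ref_2d.shape)
# torch.Size([2, 4, 8]) torch.Size([2, 4, 1, 2])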
@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(MyCustomBaseTransformerLayer):
    """Implements a decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expanded to the number of attentions in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operations
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """
    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        self.fp16_enabled = False
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.
        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__}')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            # temporal self attention
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor(
                        [[bev_h, bev_w]], device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            # spatial cross attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
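The reference-point layout used throughout this file is easiest to verify standalone. The sketch below re-creates the '3d' branch of `get_reference_points` with toy sizes on CPU; it mirrors the method rather than importing it, so it runs without the project's plugin dependencies:

import torch

H, W, Z, num_points_in_pillar, bs = 4, 5, 8, 4, 1
zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar).view(-1, 1, 1).expand(
    num_points_in_pillar, H, W) / Z
xs = torch.linspace(0.5, W - 0.5, W).view(1, 1, W).expand(
    num_points_in_pillar, H, W) / W
ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(
    num_points_in_pillar, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
print(ref_3d.shape)  # torch.Size([1, 4, 20, 3]) == (bs, num_points_in_pillar, H*W, 3)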
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/multi_scale_deformable_attn_function.py
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


class MultiScaleDeformableAttnFunction_fp16(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image to column.
        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of the backward function.

        Args:
            grad_output (Tensor): Gradient of the output tensor of forward.
        Returns:
            Tuple[Tensor]: Gradients of the input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
class MultiScaleDeformableAttnFunction_fp32(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image to column.
        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of the backward function.

        Args:
            grad_output (Tensor): Gradient of the output tensor of forward.
        Returns:
            Tuple[Tensor]: Gradients of the input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
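The pattern both classes follow, `custom_fwd(cast_inputs=...)` on forward and `custom_bwd` on backward, is what makes them safe under AMP autocast. A toy, CPU-runnable Function (hypothetical, not part of this commit) shows the same decorator stack without needing the CUDA extension:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.autograd.function import Function, once_differentiable


class ScaleByTwo(Function):
    # Under autocast, cast_inputs upcasts half-precision inputs before
    # forward runs; custom_bwd keeps backward in the dtype forward used.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2


x = torch.rand(3, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])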
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import build_attention
import math
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32, MultiScaleDeformableAttnFunction_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@ATTENTION.register_module()
class SpatialCrossAttention(BaseModule):
    """An attention module used in BEVFormer.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_cams (int): The number of cameras.
        dropout (float): A Dropout layer on `inp_residual`.
            Default: 0.1.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        deformable_attention (dict): The config for the deformable attention
            used in SCA.
    """
    def __init__(self,
                 embed_dims=256,
                 num_cams=6,
                 pc_range=None,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 deformable_attention=dict(
                     type='MSDeformableAttention3D',
                     embed_dims=256,
                     num_levels=4),
                 **kwargs):
        super(SpatialCrossAttention, self).__init__(init_cfg)

        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.deformable_attention = build_attention(deformable_attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.batch_first = batch_first
        self.init_weight()

    def init_weight(self):
        """Default initialization for Parameters of Module."""
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos',
                          'reference_points_cam'))
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):
        """Forward Function of Detr3DCrossAtten.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`. (B, N, C, H, W)
            residual (Tensor): The tensor used for addition, with the
                same shape as `x`. Default None. If None, `x` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, 4),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area;
                or (N, Length_{query}, num_levels, 4), with
                two additional dimensions (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
            slots = torch.zeros_like(query)
        if query_pos is not None:
            query = query + query_pos

        bs, num_query, _ = query.size()
        # bevformer reference_points_cam shape: (num_cam, bs, h*w, num_points_in_pillar, 2)
        D = reference_points_cam.size(3)
        indexes = []
        for i, mask_per_img in enumerate(bev_mask):
            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            indexes.append(index_query_per_img)
        max_len = max([len(each) for each in indexes])

        # each camera only interacts with its corresponding BEV queries.
        # This step can greatly save GPU memory.
        queries_rebatch = query.new_zeros(
            [bs, self.num_cams, max_len, self.embed_dims])
        reference_points_rebatch = reference_points_cam.new_zeros(
            [bs, self.num_cams, max_len, D, 2])

        for j in range(bs):
            for i, reference_points_per_img in enumerate(reference_points_cam):
                index_query_per_img = indexes[i]
                queries_rebatch[j, i, :len(index_query_per_img)] = \
                    query[j, index_query_per_img]
                reference_points_rebatch[j, i, :len(index_query_per_img)] = \
                    reference_points_per_img[j, index_query_per_img]

        num_cams, l, bs, embed_dims = key.shape

        key = key.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)

        queries = self.deformable_attention(
            query=queries_rebatch.view(bs * self.num_cams, max_len,
                                       self.embed_dims),
            key=key,
            value=value,
            reference_points=reference_points_rebatch.view(
                bs * self.num_cams, max_len, D, 2),
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index).view(
                bs, self.num_cams, max_len, self.embed_dims)

        for j in range(bs):
            for i, index_query_per_img in enumerate(indexes):
                slots[j, index_query_per_img] += \
                    queries[j, i, :len(index_query_per_img)]

        count = bev_mask.sum(-1) > 0
        count = count.permute(1, 2, 0).sum(-1)
        count = torch.clamp(count, min=1.0)
        slots = slots / count[..., None]
        slots = self.output_proj(slots)

        return self.dropout(slots) + inp_residual
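The rebatching above is the memory-saving core of SCA: each camera attends only to the BEV queries whose pillars actually project into its image. A toy reproduction of the index/gather step (2 cameras, 6 queries; all sizes illustrative):

import torch

num_cam, bs, num_query, D, dims = 2, 1, 6, 4, 8
bev_mask = torch.zeros(num_cam, bs, num_query, D, dtype=torch.bool)
bev_mask[0, :, :3] = True   # cam 0 sees queries 0-2
bev_mask[1, :, 2:] = True   # cam 1 sees queries 2-5

query = torch.rand(bs, num_query, dims)
indexes = [m[0].sum(-1).nonzero().squeeze(-1) for m in bev_mask]
max_len = max(len(idx) for idx in indexes)

queries_rebatch = query.new_zeros(bs, num_cam, max_len, dims)
for j in range(bs):
    for i, idx in enumerate(indexes):
        queries_rebatch[j, i, :len(idx)] = query[j, idx]
print(max_len, queries_rebatch.shape)  # 4 torch.Size([1, 2, 4, 8])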
@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
    """An attention module used in BEVFormer based on Deformable-DETR.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 8.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=8,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.batch_first = batch_first
        self.output_proj = None
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)

        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1, 2).repeat(
            1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (bs, num_query, embed_dims).
            key (Tensor): The key tensor with shape
                `(bs, num_key, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(bs, num_key, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area;
                or (N, Length_{query}, num_levels, 4), with
                two additional dimensions (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos

        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels,
            self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)

        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points)

        if reference_points.shape[-1] == 2:
            """
            Each BEV query owns `num_Z_anchors` points in 3D space at
            different heights. After projecting, each BEV query has
            `num_Z_anchors` reference points in each 2D image.
            For each reference point, we sample `num_points` sampling points,
            so `num_Z_anchors` reference points yield
            `num_points * num_Z_anchors` sampling points overall.
            """
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

            bs, num_query, num_Z_anchors, xy = reference_points.shape
            reference_points = reference_points[:, :, None, None, None, :, :]
            sampling_offsets = sampling_offsets / \
                offset_normalizer[None, None, None, :, None, :]
            bs, num_query, num_heads, num_levels, num_all_points, xy = \
                sampling_offsets.shape
            sampling_offsets = sampling_offsets.view(
                bs, num_query, num_heads, num_levels,
                num_all_points // num_Z_anchors, num_Z_anchors, xy)
            sampling_locations = reference_points + sampling_offsets
            bs, num_query, num_heads, num_levels, num_points, \
                num_Z_anchors, xy = sampling_locations.shape
            assert num_all_points == num_points * num_Z_anchors

            sampling_locations = sampling_locations.view(
                bs, num_query, num_heads, num_levels, num_all_points, xy)

        elif reference_points.shape[-1] == 4:
            assert False
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')

        # sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2
        # attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points

        if torch.cuda.is_available() and value.is_cuda:
            if value.dtype == torch.float16:
                # fp16 deformable attention is unstable, so the fp32 kernel
                # is used even for half-precision inputs
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return output
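The Z-anchor bookkeeping in the `reference_points.shape[-1] == 2` branch is dense; the sketch below replays it with toy sizes (offset normalization omitted) to show how the `num_all_points` learned offsets are folded so each of the `num_Z_anchors` projected reference points gets its own `num_points` samples:

import torch

bs, num_query, num_heads, num_levels = 1, 5, 2, 1
num_Z_anchors, num_points = 4, 2  # 8 sampling points per (head, level)
num_all_points = num_points * num_Z_anchors

reference_points = torch.rand(bs, num_query, num_Z_anchors, 2)
sampling_offsets = torch.rand(bs, num_query, num_heads, num_levels,
                              num_all_points, 2)

ref = reference_points[:, :, None, None, None, :, :]
off = sampling_offsets.view(bs, num_query, num_heads, num_levels,
                            num_all_points // num_Z_anchors, num_Z_anchors, 2)
sampling_locations = (ref + off).view(
    bs, num_query, num_heads, num_levels, num_all_points, 2)
print(sampling_locations.shape)  # torch.Size([1, 5, 2, 1, 8, 2])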
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import ATTENTION
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@ATTENTION.register_module()
class TemporalSelfAttention(BaseModule):
    """An attention module used in BEVFormer based on Deformable-DETR.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        num_bev_queue (int): In this version, we only use one history BEV
            and one current BEV, so the length of the BEV queue is 2.
    """
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 num_bev_queue=2,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):

        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.num_bev_queue = num_bev_queue
        self.sampling_offsets = nn.Linear(
            embed_dims * self.num_bev_queue,
            num_bev_queue * num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims * self.num_bev_queue,
            num_bev_queue * num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1, 2).repeat(
            1, self.num_levels * self.num_bev_queue, self.num_points, 1)

        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area;
                or (N, Length_{query}, num_levels, 4), with
                two additional dimensions (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            assert self.batch_first
            bs, len_bev, c = query.shape
            value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
            # value = torch.cat([query, query], 0)

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, embed_dims = query.shape
        _, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        assert self.num_bev_queue == 2

        query = torch.cat([value[:bs], query], -1)
        value = self.value_proj(value)

        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)

        value = value.reshape(bs * self.num_bev_queue,
                              num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query)
        sampling_offsets = sampling_offsets.view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(
            bs, num_query, self.num_heads, self.num_bev_queue,
            self.num_levels, self.num_points)

        attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads,
                     self.num_levels, self.num_points).contiguous()
        sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads,
                     self.num_levels, self.num_points, 2)

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]

        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        # output shape (bs*num_bev_queue, num_query, embed_dims)
        # (bs*num_bev_queue, num_query, embed_dims) -> (num_query, embed_dims, bs*num_bev_queue)
        output = output.permute(1, 2, 0)

        # fuse history value and current value
        # (num_query, embed_dims, bs*num_bev_queue) -> (num_query, embed_dims, bs, num_bev_queue)
        output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
        output = output.mean(-1)

        # (num_query, embed_dims, bs) -> (bs, num_query, embed_dims)
        output = output.permute(2, 0, 1)

        output = self.output_proj(output)

        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
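The tail of `forward` undoes the batch doubling: results for the history and current BEV sit interleaved in the first dimension and are averaged per query. A toy replay of that reshape chain (sizes illustrative):

import torch

bs, num_bev_queue, num_query, embed_dims = 1, 2, 6, 8
output = torch.rand(bs * num_bev_queue, num_query, embed_dims)

output = output.permute(1, 2, 0)                       # (nq, dims, bs*queue)
output = output.view(num_query, embed_dims, bs, num_bev_queue)
output = output.mean(-1)                               # fuse history & current
output = output.permute(2, 0, 1)                       # (bs, nq, dims)
print(output.shape)  # torch.Size([1, 6, 8])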
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer.py
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
    """Implements the Detr3D transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 rotate_center=[100, 100],
                 **kwargs):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center
    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 3)
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) \
                    or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution='uniform', bias=0.)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
    def get_bev_features(
            self,
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            prev_bev=None,
            **kwargs):
        """Obtain BEV features."""
        bs = mlvl_feats[0].size(0)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # obtain rotation angle and shift with ego motion
        delta_x = np.array([each['can_bus'][0]
                            for each in kwargs['img_metas']])
        delta_y = np.array([each['can_bus'][1]
                            for each in kwargs['img_metas']])
        ego_angle = np.array(
            [each['can_bus'][-2] / np.pi * 180
             for each in kwargs['img_metas']])
        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
        translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = translation_length * \
            np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
        shift_x = translation_length * \
            np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
        shift_y = shift_y * self.use_shift
        shift_x = shift_x * self.use_shift
        shift = bev_queries.new_tensor(
            [shift_x, shift_y]).permute(1, 0)  # xy, bs -> bs, xy

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)
            if self.rotate_prev_bev:
                for i in range(bs):
                    # num_prev_bev = prev_bev.size(1)
                    rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(
                        bev_h, bev_w, -1).permute(2, 0, 1)
                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
                                          center=self.rotate_center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # add can bus signals
        can_bus = bev_queries.new_tensor(
            [each['can_bus'] for each in kwargs['img_metas']])  # [:, :]
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        bev_queries = bev_queries + can_bus * self.use_can_bus

        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None,
                                            lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            **kwargs)

        return bev_embed
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed',
                         'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=[0.512, 0.512],
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from
                different levels. Each element has shape
                [bs, num_cams, embed_dims, h, w].
            bev_queries (Tensor): (bev_h*bev_w, c)
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
            object_query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only would
                be passed when `with_box_refine` is True. Default to None.
        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.
                - bev_embed: BEV features
                - inter_states: Outputs from decoder. If
                    return_intermediate_dec is True output has shape \
                    (num_dec_layers, bs, num_query, embed_dims), else has \
                    shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference \
                    points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference \
                    points in decoder, has shape \
                    (num_dec_layers, bs, num_query, embed_dims)
                - enc_outputs_class: The classification score of \
                    proposals generated from \
                    encoder's feature maps, has shape \
                    (batch, h*w, num_classes). \
                    Only would be returned when `as_two_stage` is True, \
                    otherwise None.
                - enc_outputs_coord_unact: The regression results \
                    generated from encoder's feature maps, has shape \
                    (batch, h*w, 4). Only would \
                    be returned when `as_two_stage` is True, \
                    otherwise None.
        """
        bev_embed = self.get_bev_features(
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        bev_embed = bev_embed.permute(1, 0, 2)

        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor(
                [[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            **kwargs)

        inter_references_out = inter_references

        return bev_embed, inter_states, init_reference_out, \
            inter_references_out
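The ego-motion shift in `get_bev_features` converts a CAN-bus translation into a fraction of the BEV grid. A numeric sketch with made-up deltas (1.2 m forward, 0.4 m left, yaw 0.3 rad) and the default 0.512 m grid, mirroring the formulas above:

import numpy as np

delta_x, delta_y = np.array([1.2]), np.array([0.4])
ego_angle = np.array([0.3]) / np.pi * 180   # rad -> deg
grid_length_y = grid_length_x = 0.512
bev_h = bev_w = 200

translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) \
    / grid_length_y / bev_h
shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) \
    / grid_length_x / bev_w
print(shift_x, shift_y)  # fraction of the BEV grid the ego moved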
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/modules/transformer_occ.py
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Xiaoyu Tian
# ---------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, Conv3d, ConvModule,
                      caffe2_xavier_init)
@TRANSFORMER.register_module()
class TransformerOcc(BaseModule):
    """Implements the Detr3D transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 use_3d=False,
                 use_conv=False,
                 rotate_center=[100, 100],
                 num_classes=18,
                 out_dim=32,
                 pillar_h=16,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='BN', ),
                 norm_cfg_3d=dict(type='BN3d', ),
                 **kwargs):
        super(TransformerOcc, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds
        self.use_3d = use_3d
        self.use_conv = use_conv
        self.pillar_h = pillar_h
        self.out_dim = out_dim
        if not use_3d:
            if use_conv:
                use_bias = norm_cfg is None
                self.decoder = nn.Sequential(
                    ConvModule(
                        self.embed_dims,
                        self.embed_dims,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=use_bias,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg),
                    ConvModule(
                        self.embed_dims,
                        self.embed_dims * 2,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=use_bias,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg),
                )
            else:
                self.decoder = nn.Sequential(
                    nn.Linear(self.embed_dims, self.embed_dims * 2),
                    nn.Softplus(),
                    nn.Linear(self.embed_dims * 2, self.embed_dims * 2),
                )
        else:
            use_bias_3d = norm_cfg_3d is None
            self.middle_dims = self.embed_dims // pillar_h
            self.decoder = nn.Sequential(
                ConvModule(
                    self.middle_dims,
                    self.out_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias_3d,
                    conv_cfg=dict(type='Conv3d'),
                    norm_cfg=norm_cfg_3d,
                    act_cfg=act_cfg),
                ConvModule(
                    self.out_dim,
                    self.out_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias_3d,
                    conv_cfg=dict(type='Conv3d'),
                    norm_cfg=norm_cfg_3d,
                    act_cfg=act_cfg),
            )
        self.predicter = nn.Sequential(
            nn.Linear(self.out_dim, self.out_dim * 2),
            nn.Softplus(),
            nn.Linear(self.out_dim * 2, num_classes),
        )
        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        self.rotate_center = rotate_center
    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        # self.reference_points = nn.Linear(self.embed_dims, 3)
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) \
                    or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        # xavier_init(self.reference_points, distribution='uniform', bias=0.)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
    def get_bev_features(
            self,
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            prev_bev=None,
            **kwargs):
        """Obtain BEV features."""
bs
=
mlvl_feats
[
0
].
size
(
0
)
bev_queries
=
bev_queries
.
unsqueeze
(
1
).
repeat
(
1
,
bs
,
1
)
bev_pos
=
bev_pos
.
flatten
(
2
).
permute
(
2
,
0
,
1
)
# obtain rotation angle and shift with ego motion
delta_x
=
np
.
array
([
each
[
'can_bus'
][
0
]
for
each
in
kwargs
[
'img_metas'
]])
delta_y
=
np
.
array
([
each
[
'can_bus'
][
1
]
for
each
in
kwargs
[
'img_metas'
]])
ego_angle
=
np
.
array
(
[
each
[
'can_bus'
][
-
2
]
/
np
.
pi
*
180
for
each
in
kwargs
[
'img_metas'
]])
grid_length_y
=
grid_length
[
0
]
grid_length_x
=
grid_length
[
1
]
translation_length
=
np
.
sqrt
(
delta_x
**
2
+
delta_y
**
2
)
translation_angle
=
np
.
arctan2
(
delta_y
,
delta_x
)
/
np
.
pi
*
180
bev_angle
=
ego_angle
-
translation_angle
shift_y
=
translation_length
*
\
np
.
cos
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_y
/
bev_h
shift_x
=
translation_length
*
\
np
.
sin
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_x
/
bev_w
shift_y
=
shift_y
*
self
.
use_shift
shift_x
=
shift_x
*
self
.
use_shift
shift
=
bev_queries
.
new_tensor
(
[
shift_x
,
shift_y
]).
permute
(
1
,
0
)
# xy, bs -> bs, xy
if
prev_bev
is
not
None
:
if
prev_bev
.
shape
[
1
]
==
bev_h
*
bev_w
:
prev_bev
=
prev_bev
.
permute
(
1
,
0
,
2
)
elif
len
(
prev_bev
.
shape
)
==
4
:
prev_bev
=
prev_bev
.
view
(
bs
,
-
1
,
bev_h
*
bev_w
).
permute
(
2
,
0
,
1
)
if
self
.
rotate_prev_bev
:
for
i
in
range
(
bs
):
# num_prev_bev = prev_bev.size(1)
rotation_angle
=
kwargs
[
'img_metas'
][
i
][
'can_bus'
][
-
1
]
tmp_prev_bev
=
prev_bev
[:,
i
].
reshape
(
bev_h
,
bev_w
,
-
1
).
permute
(
2
,
0
,
1
)
tmp_prev_bev
=
rotate
(
tmp_prev_bev
,
rotation_angle
,
center
=
self
.
rotate_center
)
tmp_prev_bev
=
tmp_prev_bev
.
permute
(
1
,
2
,
0
).
reshape
(
bev_h
*
bev_w
,
1
,
-
1
)
prev_bev
[:,
i
]
=
tmp_prev_bev
[:,
0
]
# add can bus signals
can_bus
=
bev_queries
.
new_tensor
(
[
each
[
'can_bus'
]
for
each
in
kwargs
[
'img_metas'
]])
# [:, :]
can_bus
=
self
.
can_bus_mlp
(
can_bus
)[
None
,
:,
:]
bev_queries
=
bev_queries
+
can_bus
*
self
.
use_can_bus
feat_flatten
=
[]
spatial_shapes
=
[]
for
lvl
,
feat
in
enumerate
(
mlvl_feats
):
bs
,
num_cam
,
c
,
h
,
w
=
feat
.
shape
spatial_shape
=
(
h
,
w
)
feat
=
feat
.
flatten
(
3
).
permute
(
1
,
0
,
3
,
2
)
if
self
.
use_cams_embeds
:
feat
=
feat
+
self
.
cams_embeds
[:,
None
,
None
,
:].
to
(
feat
.
dtype
)
feat
=
feat
+
self
.
level_embeds
[
None
,
None
,
lvl
:
lvl
+
1
,
:].
to
(
feat
.
dtype
)
spatial_shapes
.
append
(
spatial_shape
)
feat_flatten
.
append
(
feat
)
feat_flatten
=
torch
.
cat
(
feat_flatten
,
2
)
spatial_shapes
=
torch
.
as_tensor
(
spatial_shapes
,
dtype
=
torch
.
long
,
device
=
bev_pos
.
device
)
level_start_index
=
torch
.
cat
((
spatial_shapes
.
new_zeros
(
(
1
,)),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]))
feat_flatten
=
feat_flatten
.
permute
(
0
,
2
,
1
,
3
)
# (num_cam, H*W, bs, embed_dims)
bev_embed
=
self
.
encoder
(
bev_queries
,
feat_flatten
,
feat_flatten
,
bev_h
=
bev_h
,
bev_w
=
bev_w
,
bev_pos
=
bev_pos
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
prev_bev
=
prev_bev
,
shift
=
shift
,
**
kwargs
)
return
bev_embed
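The shift computed in `get_bev_features` turns the ego translation (in meters, from the CAN bus) into a fractional offset of the BEV grid: the translation is rotated into the BEV frame, then divided by the grid resolution and extent. The same arithmetic as a standalone numpy sketch, with purely illustrative inputs:

import numpy as np

# illustrative inputs: 1 m forward, 0.5 m left, ego yaw 30 degrees
delta_x, delta_y, ego_angle = 1.0, 0.5, 30.0
grid_length_y = grid_length_x = 0.512  # meters per BEV cell (default above)
bev_h = bev_w = 200                    # assumed grid size

translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle

# fraction of the BEV extent the grid must shift to stay ego-centered
shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
print(shift_x, shift_y)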
    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed',
                         'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=[0.512, 0.512],
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input features from different
                levels. Each element has shape
                [bs, num_cams, embed_dims, h, w].
            bev_queries (Tensor): (bev_h*bev_w, c).
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w).
            object_query_embed (Tensor): The query embedding for the
                decoder, with shape [num_query, c]. Unused here.
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only passed when
                `with_box_refine` is True. Default to None. Unused here.

        Returns:
            tuple[Tensor]: (bev_embed, outputs). `bev_embed` holds the BEV
                features reshaped to (bs, embed_dims, bev_h, bev_w);
                `outputs` holds the per-voxel occupancy logits produced by
                `self.predicter`, whose last dimension is `num_classes`.
        """
        bev_embed = self.get_bev_features(
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)
        # bev_embed shape: bs, bev_h*bev_w, embed_dims

        bs = mlvl_feats[0].size(0)
        bev_embed = bev_embed.permute(0, 2, 1).view(bs, -1, bev_h, bev_w)
        if self.use_3d:
            outputs = self.decoder(
                bev_embed.view(bs, -1, self.pillar_h, bev_h, bev_w))
            outputs = outputs.permute(0, 4, 3, 2, 1)
        elif self.use_conv:
            outputs = self.decoder(bev_embed)
            outputs = outputs.view(bs, -1, self.pillar_h, bev_h,
                                   bev_w).permute(0, 3, 4, 2, 1)
        else:
            outputs = self.decoder(bev_embed.permute(0, 2, 3, 1))
            outputs = outputs.view(bs, bev_h, bev_w, self.pillar_h,
                                   self.out_dim)
        outputs = self.predicter(outputs)
        # print('outputs', type(outputs))
        return bev_embed, outputs
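The three decoder branches in `forward` differ only in how the flat BEV embedding is reshaped before decoding. A small sketch of the `use_3d` reshape, with hypothetical sizes (embed_dims=256, pillar_h=16, a 200x200 grid): the channel dimension is split into `pillar_h` vertical slices, so the 3D convolutions see `embed_dims // pillar_h` channels per voxel.

import torch

# hypothetical sizes, for illustration only
bs, bev_h, bev_w, pillar_h, embed_dims = 1, 200, 200, 16, 256

bev_embed = torch.randn(bs, bev_h * bev_w, embed_dims)
bev_embed = bev_embed.permute(0, 2, 1).view(bs, -1, bev_h, bev_w)

# use_3d branch: split channels into pillar_h vertical slices
vox = bev_embed.view(bs, -1, pillar_h, bev_h, bev_w)
print(vox.shape)  # torch.Size([1, 16, 16, 200, 200]) = (bs, C/pillar_h, pillar_h, H, W)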
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/__init__.py
0 → 100644
from .epoch_based_runner import EpochBasedRunner_video
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import os.path as osp

import torch
import mmcv
from mmcv.runner.base_runner import BaseRunner
from mmcv.runner.epoch_based_runner import EpochBasedRunner
from mmcv.runner.builder import RUNNERS
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.utils import get_host_info
from pprint import pprint
from mmcv.parallel.data_container import DataContainer
@RUNNERS.register_module()
class EpochBasedRunner_video(EpochBasedRunner):
    '''
    # basic logic
    input_sequence = [a, b, c]  # given a sequence of samples
    prev_bev = None
    for each in input_sequence[:-1]:
        prev_bev = eval_model(each, prev_bev)  # inference only.
    model(input_sequence[-1], prev_bev)  # train on the last sample.
    '''
    def __init__(self,
                 model,
                 eval_model=None,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'],
                 max_iters=None,
                 max_epochs=None):
        super().__init__(model,
                         batch_processor,
                         optimizer,
                         work_dir,
                         logger,
                         meta,
                         max_iters,
                         max_epochs)
        # note: appending to the mutable default list mutates it across
        # instances; kept as in the original code
        keys.append('img_metas')
        self.keys = keys
        self.eval_model = eval_model
        self.eval_model.eval()
    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            assert False
            # outputs = self.batch_processor(
            #     self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            num_samples = data_batch['img'].data[0].size(1)
            data_list = []
            prev_bev = None
            # split the batched sequence into per-frame data dicts
            for i in range(num_samples):
                data = {}
                for key in self.keys:
                    if key not in ['img_metas', 'img', 'points']:
                        data[key] = data_batch[key]
                    else:
                        if key == 'img':
                            data['img'] = DataContainer(
                                data=[data_batch['img'].data[0][:, i]],
                                cpu_only=data_batch['img'].cpu_only,
                                stack=True)
                        elif key == 'img_metas':
                            data['img_metas'] = DataContainer(
                                data=[[each[i] for each in
                                       data_batch['img_metas'].data[0]]],
                                cpu_only=data_batch['img_metas'].cpu_only)
                        else:
                            assert False
                data_list.append(data)
            with torch.no_grad():
                # warm up the temporal BEV state on all but the last frame
                for i in range(num_samples - 1):
                    if data_list[i]['img_metas'].data[0][0]['prev_bev_exists']:
                        data_list[i]['prev_bev'] = DataContainer(
                            data=[prev_bev], cpu_only=False)
                    prev_bev = self.eval_model.val_step(
                        data_list[i], self.optimizer, **kwargs)
            if data_list[-1]['img_metas'].data[0][0]['prev_bev_exists']:
                data_list[-1]['prev_bev'] = DataContainer(
                    data=[prev_bev], cpu_only=False)
            outputs = self.model.train_step(data_list[-1], self.optimizer,
                                            **kwargs)
        else:
            assert False
            # outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            ' and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
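Because the runner is registered via `@RUNNERS.register_module()`, it can be selected by name from an mmcv-style config. A hedged sketch of such a config entry (the exact keys used elsewhere in this repo are an assumption):

# hypothetical config fragment; the max_epochs value is illustrative
runner = dict(type='EpochBasedRunner_video', max_epochs=24)

Note the design: `eval_model` is a separate copy of the model kept in eval mode and run under `torch.no_grad()`, so the training graph only spans the last frame of each sequence while earlier frames merely supply `prev_bev`.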
autonomous_driving/occupancy_prediction/projects/mmdet3d_plugin/core/bbox/assigners/__init__.py
0 → 100644
from .hungarian_assigner_3d import HungarianAssigner3D

__all__ = ['HungarianAssigner3D']