Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dcnv3
Commits
41b18fd8
Commit
41b18fd8
authored
Jan 06, 2025
by
zhe chen
Browse files
Use pre-commit to reformat code
Use pre-commit to reformat code
parent
ff20ea39
Changes
390
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1376 additions
and
1412 deletions
+1376
-1412
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/modules/dcnv3.py
...et3d/baseline/models/backbones/ops_dcnv3/modules/dcnv3.py
+12
-12
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/setup.py
...ugin/mmdet3d/baseline/models/backbones/ops_dcnv3/setup.py
+20
-25
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/src/cuda/dcnv3_cuda.cu
...aseline/models/backbones/ops_dcnv3/src/cuda/dcnv3_cuda.cu
+1
-1
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh
...models/backbones/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh
+1
-1
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/test.py
...lugin/mmdet3d/baseline/models/backbones/ops_dcnv3/test.py
+21
-24
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/__init__.py
...e-v2/plugin/mmdet3d/baseline/models/detectors/__init__.py
+0
-2
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/baseline.py
...e-v2/plugin/mmdet3d/baseline/models/detectors/baseline.py
+16
-16
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/road_bev.py
...e-v2/plugin/mmdet3d/baseline/models/detectors/road_bev.py
+17
-18
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/__init__.py
...nlane-v2/plugin/mmdet3d/baseline/models/heads/__init__.py
+0
-5
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/custom_detr_head.py
.../plugin/mmdet3d/baseline/models/heads/custom_detr_head.py
+38
-32
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/lc_deformable_detr_head.py
.../mmdet3d/baseline/models/heads/lc_deformable_detr_head.py
+15
-19
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/relationship_head.py
...plugin/mmdet3d/baseline/models/heads/relationship_head.py
+5
-9
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/te_deformable_detr_head.py
.../mmdet3d/baseline/models/heads/te_deformable_detr_head.py
+19
-20
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/topology_head.py
...-v2/plugin/mmdet3d/baseline/models/heads/topology_head.py
+9
-8
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/__init__.py
...ane-v2/plugin/mmdet3d/baseline/models/modules/__init__.py
+0
-6
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/bevformer_constructer.py
.../mmdet3d/baseline/models/modules/bevformer_constructer.py
+13
-16
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/custom_base_transformer_layer.py
.../baseline/models/modules/custom_base_transformer_layer.py
+259
-260
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/decoder.py
...lane-v2/plugin/mmdet3d/baseline/models/modules/decoder.py
+371
-378
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/encoder.py
...lane-v2/plugin/mmdet3d/baseline/models/modules/encoder.py
+395
-397
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/multi_scale_deformable_attn_function.py
...ne/models/modules/multi_scale_deformable_attn_function.py
+164
-163
No files found.
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/modules/dcnv3.py
View file @
41b18fd8
...
@@ -4,15 +4,15 @@
...
@@ -4,15 +4,15 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
absolute_import
,
division
,
print_function
from
__future__
import
print_function
from
__future__
import
division
import
warnings
import
warnings
import
torch
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.init
import
xavier_uniform_
,
constant_
from
torch
import
nn
from
torch.nn.init
import
constant_
,
xavier_uniform_
from
..functions
import
DCNv3Function
,
dcnv3_core_pytorch
from
..functions
import
DCNv3Function
,
dcnv3_core_pytorch
...
@@ -72,7 +72,7 @@ def build_act_layer(act_layer):
...
@@ -72,7 +72,7 @@ def build_act_layer(act_layer):
def
_is_power_of_2
(
n
):
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
raise
ValueError
(
"
invalid input for _is_power_of_2: {} (type: {})
"
.
format
(
n
,
type
(
n
)))
'
invalid input for _is_power_of_2: {} (type: {})
'
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
...
@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module):
...
@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module):
if
not
_is_power_of_2
(
_d_per_group
):
if
not
_is_power_of_2
(
_d_per_group
):
warnings
.
warn
(
warnings
.
warn
(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"
which is more efficient in our CUDA implementation.
"
)
'
which is more efficient in our CUDA implementation.
'
)
self
.
offset_scale
=
offset_scale
self
.
offset_scale
=
offset_scale
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -161,7 +161,7 @@ class DCNv3_pytorch(nn.Module):
...
@@ -161,7 +161,7 @@ class DCNv3_pytorch(nn.Module):
self
.
input_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
input_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
output_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
output_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
_reset_parameters
()
self
.
_reset_parameters
()
if
center_feature_scale
:
if
center_feature_scale
:
self
.
center_feature_scale_proj_weight
=
nn
.
Parameter
(
self
.
center_feature_scale_proj_weight
=
nn
.
Parameter
(
torch
.
zeros
((
group
,
channels
),
dtype
=
torch
.
float
))
torch
.
zeros
((
group
,
channels
),
dtype
=
torch
.
float
))
...
@@ -251,7 +251,7 @@ class DCNv3(nn.Module):
...
@@ -251,7 +251,7 @@ class DCNv3(nn.Module):
if
not
_is_power_of_2
(
_d_per_group
):
if
not
_is_power_of_2
(
_d_per_group
):
warnings
.
warn
(
warnings
.
warn
(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"
which is more efficient in our CUDA implementation.
"
)
'
which is more efficient in our CUDA implementation.
'
)
self
.
offset_scale
=
offset_scale
self
.
offset_scale
=
offset_scale
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -264,7 +264,7 @@ class DCNv3(nn.Module):
...
@@ -264,7 +264,7 @@ class DCNv3(nn.Module):
self
.
group_channels
=
channels
//
group
self
.
group_channels
=
channels
//
group
self
.
offset_scale
=
offset_scale
self
.
offset_scale
=
offset_scale
self
.
center_feature_scale
=
center_feature_scale
self
.
center_feature_scale
=
center_feature_scale
self
.
dw_conv
=
nn
.
Sequential
(
self
.
dw_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
nn
.
Conv2d
(
channels
,
channels
,
...
@@ -288,7 +288,7 @@ class DCNv3(nn.Module):
...
@@ -288,7 +288,7 @@ class DCNv3(nn.Module):
self
.
input_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
input_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
output_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
output_proj
=
nn
.
Linear
(
channels
,
channels
)
self
.
_reset_parameters
()
self
.
_reset_parameters
()
if
center_feature_scale
:
if
center_feature_scale
:
self
.
center_feature_scale_proj_weight
=
nn
.
Parameter
(
self
.
center_feature_scale_proj_weight
=
nn
.
Parameter
(
torch
.
zeros
((
group
,
channels
),
dtype
=
torch
.
float
))
torch
.
zeros
((
group
,
channels
),
dtype
=
torch
.
float
))
...
@@ -332,7 +332,7 @@ class DCNv3(nn.Module):
...
@@ -332,7 +332,7 @@ class DCNv3(nn.Module):
self
.
group
,
self
.
group_channels
,
self
.
group
,
self
.
group_channels
,
self
.
offset_scale
,
self
.
offset_scale
,
256
)
256
)
if
self
.
center_feature_scale
:
if
self
.
center_feature_scale
:
center_feature_scale
=
self
.
center_feature_scale_module
(
center_feature_scale
=
self
.
center_feature_scale_module
(
x1
,
self
.
center_feature_scale_proj_weight
,
self
.
center_feature_scale_proj_bias
)
x1
,
self
.
center_feature_scale_proj_weight
,
self
.
center_feature_scale_proj_bias
)
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/setup.py
View file @
41b18fd8
...
@@ -4,39 +4,34 @@
...
@@ -4,39 +4,34 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
import
os
import
glob
import
glob
import
os
import
torch
import
torch
from
setuptools
import
find_packages
,
setup
from
torch.utils.cpp_extension
import
CUDA_HOME
,
CppExtension
,
CUDAExtension
from
torch.utils.cpp_extension
import
CUDA_HOME
requirements
=
[
'torch'
,
'torchvision'
]
from
torch.utils.cpp_extension
import
CppExtension
from
torch.utils.cpp_extension
import
CUDAExtension
from
setuptools
import
find_packages
from
setuptools
import
setup
requirements
=
[
"torch"
,
"torchvision"
]
def
get_extensions
():
def
get_extensions
():
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
extensions_dir
=
os
.
path
.
join
(
this_dir
,
"
src
"
)
extensions_dir
=
os
.
path
.
join
(
this_dir
,
'
src
'
)
main_file
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"
*.cpp
"
))
main_file
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
'
*.cpp
'
))
source_cpu
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"
cpu
"
,
"
*.cpp
"
))
source_cpu
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
'
cpu
'
,
'
*.cpp
'
))
source_cuda
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"
cuda
"
,
"
*.cu
"
))
source_cuda
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
'
cuda
'
,
'
*.cu
'
))
sources
=
main_file
+
source_cpu
sources
=
main_file
+
source_cpu
extension
=
CppExtension
extension
=
CppExtension
extra_compile_args
=
{
"
cxx
"
:
[]}
extra_compile_args
=
{
'
cxx
'
:
[]}
define_macros
=
[]
define_macros
=
[]
if
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
:
if
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
:
extension
=
CUDAExtension
extension
=
CUDAExtension
sources
+=
source_cuda
sources
+=
source_cuda
define_macros
+=
[(
"
WITH_CUDA
"
,
None
)]
define_macros
+=
[(
'
WITH_CUDA
'
,
None
)]
extra_compile_args
[
"
nvcc
"
]
=
[
extra_compile_args
[
'
nvcc
'
]
=
[
# "-DCUDA_HAS_FP16=1",
# "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__",
...
@@ -49,7 +44,7 @@ def get_extensions():
...
@@ -49,7 +44,7 @@ def get_extensions():
include_dirs
=
[
extensions_dir
]
include_dirs
=
[
extensions_dir
]
ext_modules
=
[
ext_modules
=
[
extension
(
extension
(
"
DCNv3
"
,
'
DCNv3
'
,
sources
,
sources
,
include_dirs
=
include_dirs
,
include_dirs
=
include_dirs
,
define_macros
=
define_macros
,
define_macros
=
define_macros
,
...
@@ -60,16 +55,16 @@ def get_extensions():
...
@@ -60,16 +55,16 @@ def get_extensions():
setup
(
setup
(
name
=
"
DCNv3
"
,
name
=
'
DCNv3
'
,
version
=
"
1.0
"
,
version
=
'
1.0
'
,
author
=
"
InternImage
"
,
author
=
'
InternImage
'
,
url
=
"
https://github.com/OpenGVLab/InternImage
"
,
url
=
'
https://github.com/OpenGVLab/InternImage
'
,
description
=
description
=
"
PyTorch Wrapper for CUDA Functions of DCNv3
"
,
'
PyTorch Wrapper for CUDA Functions of DCNv3
'
,
packages
=
find_packages
(
exclude
=
(
packages
=
find_packages
(
exclude
=
(
"
configs
"
,
'
configs
'
,
"
tests
"
,
'
tests
'
,
)),
)),
ext_modules
=
get_extensions
(),
ext_modules
=
get_extensions
(),
cmdclass
=
{
"
build_ext
"
:
torch
.
utils
.
cpp_extension
.
BuildExtension
},
cmdclass
=
{
'
build_ext
'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
},
)
)
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/src/cuda/dcnv3_cuda.cu
View file @
41b18fd8
...
@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
...
@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
}
else
{
}
else
{
return
{
grad_input
,
grad_offset
,
grad_mask
};
return
{
grad_input
,
grad_offset
,
grad_mask
};
}
}
}
}
\ No newline at end of file
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh
View file @
41b18fd8
...
@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda(
...
@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda(
if
(
err
!=
cudaSuccess
)
{
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in dcnv3_col2im_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
printf
(
"error in dcnv3_col2im_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
}
}
}
}
\ No newline at end of file
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/backbones/ops_dcnv3/test.py
View file @
41b18fd8
...
@@ -4,16 +4,11 @@
...
@@ -4,16 +4,11 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
absolute_import
,
division
,
print_function
from
__future__
import
print_function
from
__future__
import
division
import
time
import
time
import
torch
import
torch.nn
as
nn
import
math
from
torch.autograd
import
gradcheck
import
torch
from
functions.dcnv3_func
import
DCNv3Function
,
dcnv3_core_pytorch
from
functions.dcnv3_func
import
DCNv3Function
,
dcnv3_core_pytorch
H_in
,
W_in
=
8
,
8
H_in
,
W_in
=
8
,
8
...
@@ -32,11 +27,11 @@ torch.manual_seed(3)
...
@@ -32,11 +27,11 @@ torch.manual_seed(3)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
check_forward_equal_with_pytorch_double
():
def
check_forward_equal_with_pytorch_double
():
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
output_pytorch
=
dcnv3_core_pytorch
(
output_pytorch
=
dcnv3_core_pytorch
(
input
.
double
(),
input
.
double
(),
...
@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double():
...
@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double():
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
output_pytorch
.
abs
()).
max
()
output_pytorch
.
abs
()).
max
()
print
(
'>>> forward double'
)
print
(
'>>> forward double'
)
print
(
f
'*
{
fwdok
}
check_forward_equal_with_pytorch_double: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
'
)
print
(
f
'*
{
fwdok
}
check_forward_equal_with_pytorch_double: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
'
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
check_forward_equal_with_pytorch_float
():
def
check_forward_equal_with_pytorch_float
():
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
output_pytorch
=
dcnv3_core_pytorch
(
output_pytorch
=
dcnv3_core_pytorch
(
input
,
input
,
...
@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float():
...
@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float():
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
output_pytorch
.
abs
()).
max
()
output_pytorch
.
abs
()).
max
()
print
(
'>>> forward float'
)
print
(
'>>> forward float'
)
print
(
f
'*
{
fwdok
}
check_forward_equal_with_pytorch_float: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
'
)
print
(
f
'*
{
fwdok
}
check_forward_equal_with_pytorch_float: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
'
)
def
check_backward_equal_with_pytorch_double
(
channels
=
4
,
grad_input
=
True
,
grad_offset
=
True
,
grad_mask
=
True
):
def
check_backward_equal_with_pytorch_double
(
channels
=
4
,
grad_input
=
True
,
grad_offset
=
True
,
grad_mask
=
True
):
...
@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
...
@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
D
=
channels
D
=
channels
input0
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
input0
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
offset0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
offset0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
mask0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask0
/=
mask0
.
sum
(
-
1
,
keepdim
=
True
)
mask0
/=
mask0
.
sum
(
-
1
,
keepdim
=
True
)
mask0
=
mask0
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
mask0
=
mask0
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
input0
.
requires_grad
=
grad_input
input0
.
requires_grad
=
grad_input
offset0
.
requires_grad
=
grad_offset
offset0
.
requires_grad
=
grad_offset
mask0
.
requires_grad
=
grad_mask
mask0
.
requires_grad
=
grad_mask
...
@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
...
@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
D
=
channels
D
=
channels
input0
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
input0
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
offset0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
offset0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
mask0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask0
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask0
/=
mask0
.
sum
(
-
1
,
keepdim
=
True
)
mask0
/=
mask0
.
sum
(
-
1
,
keepdim
=
True
)
mask0
=
mask0
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
mask0
=
mask0
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
input0
.
requires_grad
=
grad_input
input0
.
requires_grad
=
grad_input
offset0
.
requires_grad
=
grad_offset
offset0
.
requires_grad
=
grad_offset
mask0
.
requires_grad
=
grad_mask
mask0
.
requires_grad
=
grad_mask
...
@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128):
...
@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128):
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
0.01
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
offset
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
10
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
/=
mask
.
sum
(
-
1
,
keepdim
=
True
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
mask
=
mask
.
reshape
(
N
,
H_out
,
W_out
,
M
*
P
)
print
(
print
(
f
'>>> time cost: im2col_step
{
im2col_step
}
; input
{
input
.
shape
}
; points
{
P
}
'
)
f
'>>> time cost: im2col_step
{
im2col_step
}
; input
{
input
.
shape
}
; points
{
P
}
'
)
repeat
=
100
repeat
=
100
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/__init__.py
View file @
41b18fd8
from
.baseline
import
Baseline
from
.road_bev
import
ROAD_BEVFormer
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/baseline.py
View file @
41b18fd8
# ==============================================================================
# ==============================================================================
# Binaries and/or source for the following packages or projects
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# are presented under one or more of the following open source licenses:
# baseline.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
# baseline.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
#
...
@@ -21,8 +21,7 @@
...
@@ -21,8 +21,7 @@
# ==============================================================================
# ==============================================================================
import
torch
import
torch
from
mmdet3d.models
import
DETECTORS
,
build_head
,
build_neck
from
mmdet3d.models
import
DETECTORS
,
build_neck
,
build_head
from
mmdet3d.models.detectors
import
MVXTwoStageDetector
from
mmdet3d.models.detectors
import
MVXTwoStageDetector
...
@@ -59,13 +58,14 @@ class Baseline(MVXTwoStageDetector):
...
@@ -59,13 +58,14 @@ class Baseline(MVXTwoStageDetector):
if
type
(
x
)
in
[
list
,
tuple
]:
if
type
(
x
)
in
[
list
,
tuple
]:
x
=
x
[
0
]
x
=
x
[
0
]
# view transformation
# view transformation
bev_feat
=
self
.
img_view_transformer
(
bev_feat
=
self
.
img_view_transformer
(
x
,
x
,
torch
.
cat
([
torch
.
cat
([
torch
.
cat
([
torch
.
tensor
(
l
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
unsqueeze
(
0
)
for
l
in
img_metas
[
b
][
'lidar2img'
]],
dim
=
0
).
unsqueeze
(
0
)
torch
.
cat
([
torch
.
tensor
(
l
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
unsqueeze
(
0
)
for
l
in
for
b
in
range
(
B
)],
dim
=
0
),
img_metas
[
b
][
'lidar2img'
]],
dim
=
0
).
unsqueeze
(
0
)
for
b
in
range
(
B
)],
dim
=
0
),
(
img_metas
[
0
][
'img_shape'
][
0
][
0
],
img_metas
[
0
][
'img_shape'
][
0
][
1
]),
(
img_metas
[
0
][
'img_shape'
][
0
][
0
],
img_metas
[
0
][
'img_shape'
][
0
][
1
]),
)
)
_
,
output_dim
,
ouput_H
,
output_W
=
x
.
shape
_
,
output_dim
,
ouput_H
,
output_W
=
x
.
shape
...
@@ -76,10 +76,10 @@ class Baseline(MVXTwoStageDetector):
...
@@ -76,10 +76,10 @@ class Baseline(MVXTwoStageDetector):
lc_img_metas
=
[{
lc_img_metas
=
[{
'batch_input_shape'
:
(
bev_feat
.
shape
[
-
2
],
bev_feat
.
shape
[
-
1
]),
'batch_input_shape'
:
(
bev_feat
.
shape
[
-
2
],
bev_feat
.
shape
[
-
1
]),
'img_shape'
:
(
bev_feat
.
shape
[
-
2
],
bev_feat
.
shape
[
-
1
],
None
),
'img_shape'
:
(
bev_feat
.
shape
[
-
2
],
bev_feat
.
shape
[
-
1
],
None
),
'scale_factor'
:
None
,
# dummy
'scale_factor'
:
None
,
# dummy
}
for
_
in
range
(
B
)]
}
for
_
in
range
(
B
)]
all_lc_cls_scores_list
,
all_lc_preds_list
,
lc_outs_dec_list
=
self
.
lc_head
(
all_lc_cls_scores_list
,
all_lc_preds_list
,
lc_outs_dec_list
=
self
.
lc_head
(
[
bev_feat
],
[
bev_feat
],
lc_img_metas
,
lc_img_metas
,
)
)
...
@@ -91,7 +91,7 @@ class Baseline(MVXTwoStageDetector):
...
@@ -91,7 +91,7 @@ class Baseline(MVXTwoStageDetector):
'scale_factor'
:
img_metas
[
b
][
'scale_factor'
],
'scale_factor'
:
img_metas
[
b
][
'scale_factor'
],
}
for
b
in
range
(
B
)]
}
for
b
in
range
(
B
)]
all_te_cls_scores_list
,
all_te_preds_list
,
te_outs_dec_list
=
self
.
te_head
(
all_te_cls_scores_list
,
all_te_preds_list
,
te_outs_dec_list
=
self
.
te_head
(
[
pv_feat
],
[
pv_feat
],
te_img_metas
,
te_img_metas
,
)
)
...
@@ -149,7 +149,7 @@ class Baseline(MVXTwoStageDetector):
...
@@ -149,7 +149,7 @@ class Baseline(MVXTwoStageDetector):
})
})
# te
# te
te_loss_dict
,
te_assign_results
=
self
.
te_head
.
loss
(
te_loss_dict
,
te_assign_results
=
self
.
te_head
.
loss
(
outs
[
'all_te_cls_scores_list'
],
outs
[
'all_te_cls_scores_list'
],
outs
[
'all_te_preds_list'
],
outs
[
'all_te_preds_list'
],
...
@@ -186,20 +186,20 @@ class Baseline(MVXTwoStageDetector):
...
@@ -186,20 +186,20 @@ class Baseline(MVXTwoStageDetector):
})
})
return
losses
return
losses
def
forward_test
(
self
,
img
,
img_metas
,
**
kwargs
):
def
forward_test
(
self
,
img
,
img_metas
,
**
kwargs
):
outs
=
self
.
simple_forward
(
img
,
img_metas
)
outs
=
self
.
simple_forward
(
img
,
img_metas
)
pred_lc
=
self
.
lc_head
.
get_bboxes
(
pred_lc
=
self
.
lc_head
.
get_bboxes
(
outs
[
'all_lc_cls_scores_list'
],
outs
[
'all_lc_cls_scores_list'
],
outs
[
'all_lc_preds_list'
],
outs
[
'all_lc_preds_list'
],
outs
[
'lc_img_metas'
],
outs
[
'lc_img_metas'
],
)
)
pred_te
=
self
.
te_head
.
get_bboxes
(
pred_te
=
self
.
te_head
.
get_bboxes
(
outs
[
'all_te_cls_scores_list'
],
outs
[
'all_te_cls_scores_list'
],
outs
[
'all_te_preds_list'
],
outs
[
'all_te_preds_list'
],
outs
[
'te_img_metas'
],
outs
[
'te_img_metas'
],
rescale
=
True
,
rescale
=
True
,
)
)
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/detectors/road_bev.py
View file @
41b18fd8
...
@@ -3,17 +3,13 @@
...
@@ -3,17 +3,13 @@
# ---------------------------------------------
# ---------------------------------------------
# Modified by Tianyu Li
# Modified by Tianyu Li
# ---------------------------------------------
# ---------------------------------------------
import
time
import
copy
import
numpy
as
np
import
torch
from
mmcv.runner
import
force_fp32
,
auto_fp16
import
torch
from
mmdet.core
import
bbox2result
from
mmcv.runner
import
auto_fp16
from
mmdet.models
import
DETECTORS
from
mmdet.models.builder
import
build_head
from
mmdet3d.models.builder
import
build_neck
from
mmdet3d.models.builder
import
build_neck
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
mmdet.models
import
DETECTORS
from
mmdet.models.builder
import
build_head
@
DETECTORS
.
register_module
()
@
DETECTORS
.
register_module
()
...
@@ -79,7 +75,6 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -79,7 +75,6 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
'prev_angle'
:
0
,
'prev_angle'
:
0
,
}
}
def
extract_img_feat
(
self
,
img
,
img_metas
,
len_queue
=
None
):
def
extract_img_feat
(
self
,
img
,
img_metas
,
len_queue
=
None
):
"""Extract features of images."""
"""Extract features of images."""
B
=
img
.
size
(
0
)
B
=
img
.
size
(
0
)
...
@@ -108,7 +103,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -108,7 +103,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
for
img_feat
in
img_feats
:
for
img_feat
in
img_feats
:
BN
,
C
,
H
,
W
=
img_feat
.
size
()
BN
,
C
,
H
,
W
=
img_feat
.
size
()
if
len_queue
is
not
None
:
if
len_queue
is
not
None
:
img_feats_reshaped
.
append
(
img_feat
.
view
(
int
(
B
/
len_queue
),
len_queue
,
int
(
BN
/
B
),
C
,
H
,
W
))
img_feats_reshaped
.
append
(
img_feat
.
view
(
int
(
B
/
len_queue
),
len_queue
,
int
(
BN
/
B
),
C
,
H
,
W
))
else
:
else
:
img_feats_reshaped
.
append
(
img_feat
.
view
(
B
,
int
(
BN
/
B
),
C
,
H
,
W
))
img_feats_reshaped
.
append
(
img_feat
.
view
(
B
,
int
(
BN
/
B
),
C
,
H
,
W
))
return
img_feats_reshaped
return
img_feats_reshaped
...
@@ -118,7 +113,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -118,7 +113,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
"""Extract features from images and points."""
"""Extract features from images and points."""
img_feats
=
self
.
extract_img_feat
(
img
,
img_metas
,
len_queue
=
len_queue
)
img_feats
=
self
.
extract_img_feat
(
img
,
img_metas
,
len_queue
=
len_queue
)
return
img_feats
return
img_feats
def
forward_dummy
(
self
,
img
):
def
forward_dummy
(
self
,
img
):
...
@@ -139,7 +134,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -139,7 +134,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
return
self
.
forward_train
(
**
kwargs
)
return
self
.
forward_train
(
**
kwargs
)
else
:
else
:
return
self
.
forward_test
(
**
kwargs
)
return
self
.
forward_test
(
**
kwargs
)
def
obtain_history_bev
(
self
,
imgs_queue
,
img_metas_list
):
def
obtain_history_bev
(
self
,
imgs_queue
,
img_metas_list
):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
"""
...
@@ -148,7 +143,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -148,7 +143,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
prev_bev
=
None
prev_bev
=
None
bs
,
len_queue
,
num_cams
,
C
,
H
,
W
=
imgs_queue
.
shape
bs
,
len_queue
,
num_cams
,
C
,
H
,
W
=
imgs_queue
.
shape
imgs_queue
=
imgs_queue
.
reshape
(
bs
*
len_queue
,
num_cams
,
C
,
H
,
W
)
imgs_queue
=
imgs_queue
.
reshape
(
bs
*
len_queue
,
num_cams
,
C
,
H
,
W
)
img_feats_list
=
self
.
extract_feat
(
img
=
imgs_queue
,
len_queue
=
len_queue
)
img_feats_list
=
self
.
extract_feat
(
img
=
imgs_queue
,
len_queue
=
len_queue
)
for
i
in
range
(
len_queue
):
for
i
in
range
(
len_queue
):
img_metas
=
[
each
[
i
]
for
each
in
img_metas_list
]
img_metas
=
[
each
[
i
]
for
each
in
img_metas_list
]
...
@@ -183,7 +178,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -183,7 +178,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
lane_feats
=
outs
[
'history_states'
]
lane_feats
=
outs
[
'history_states'
]
if
self
.
lclc_head
is
not
None
:
if
self
.
lclc_head
is
not
None
:
lclc_losses
=
self
.
lclc_head
.
forward_train
(
lane_feats
,
lane_assign_result
,
lane_feats
,
lane_assign_result
,
gt_topology_lclc
)
lclc_losses
=
self
.
lclc_head
.
forward_train
(
lane_feats
,
lane_assign_result
,
lane_feats
,
lane_assign_result
,
gt_topology_lclc
)
for
loss
in
lclc_losses
:
for
loss
in
lclc_losses
:
losses
[
'lclc_head.'
+
loss
]
=
lclc_losses
[
loss
]
losses
[
'lclc_head.'
+
loss
]
=
lclc_losses
[
loss
]
...
@@ -201,13 +197,15 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -201,13 +197,15 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
te_losses
=
{}
te_losses
=
{}
bbox_outs
=
self
.
bbox_head
(
front_view_img_feats
,
bbox_img_metas
)
bbox_outs
=
self
.
bbox_head
(
front_view_img_feats
,
bbox_img_metas
)
bbox_losses
,
te_assign_result
=
self
.
bbox_head
.
loss
(
bbox_outs
,
gt_te
,
gt_te_labels
,
bbox_img_metas
,
gt_bboxes_ignore
)
bbox_losses
,
te_assign_result
=
self
.
bbox_head
.
loss
(
bbox_outs
,
gt_te
,
gt_te_labels
,
bbox_img_metas
,
gt_bboxes_ignore
)
for
loss
in
bbox_losses
:
for
loss
in
bbox_losses
:
te_losses
[
'bbox_head.'
+
loss
]
=
bbox_losses
[
loss
]
te_losses
[
'bbox_head.'
+
loss
]
=
bbox_losses
[
loss
]
if
self
.
lcte_head
is
not
None
:
if
self
.
lcte_head
is
not
None
:
te_feats
=
bbox_outs
[
'history_states'
]
te_feats
=
bbox_outs
[
'history_states'
]
lcte_losses
=
self
.
lcte_head
.
forward_train
(
lane_feats
,
lane_assign_result
,
te_feats
,
te_assign_result
,
gt_topology_lcte
)
lcte_losses
=
self
.
lcte_head
.
forward_train
(
lane_feats
,
lane_assign_result
,
te_feats
,
te_assign_result
,
gt_topology_lcte
)
for
loss
in
lcte_losses
:
for
loss
in
lcte_losses
:
te_losses
[
'lcte_head.'
+
loss
]
=
lcte_losses
[
loss
]
te_losses
[
'lcte_head.'
+
loss
]
=
lcte_losses
[
loss
]
...
@@ -263,7 +261,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -263,7 +261,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
bbox_results
=
self
.
bbox_head
.
get_bboxes
(
bbox_outs
,
bbox_img_metas
,
rescale
=
rescale
)
bbox_results
=
self
.
bbox_head
.
get_bboxes
(
bbox_outs
,
bbox_img_metas
,
rescale
=
rescale
)
else
:
else
:
bbox_results
=
[
None
for
_
in
range
(
batchsize
)]
bbox_results
=
[
None
for
_
in
range
(
batchsize
)]
if
self
.
bbox_head
is
not
None
and
self
.
lcte_head
is
not
None
:
if
self
.
bbox_head
is
not
None
and
self
.
lcte_head
is
not
None
:
te_feats
=
bbox_outs
[
'history_states'
]
te_feats
=
bbox_outs
[
'history_states'
]
lcte_results
=
self
.
lcte_head
.
get_relationship
(
lane_feats
,
te_feats
)
lcte_results
=
self
.
lcte_head
.
get_relationship
(
lane_feats
,
te_feats
)
...
@@ -280,7 +278,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
...
@@ -280,7 +278,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
results_list
=
[
dict
()
for
i
in
range
(
len
(
img_metas
))]
results_list
=
[
dict
()
for
i
in
range
(
len
(
img_metas
))]
new_prev_bev
,
bbox_results
,
lane_results
,
lclc_results
,
lcte_results
=
self
.
simple_test_pts
(
new_prev_bev
,
bbox_results
,
lane_results
,
lclc_results
,
lcte_results
=
self
.
simple_test_pts
(
img_feats
,
img_metas
,
img
,
prev_bev
,
rescale
=
rescale
)
img_feats
,
img_metas
,
img
,
prev_bev
,
rescale
=
rescale
)
for
result_dict
,
bbox
,
lane
,
lclc
,
lcte
in
zip
(
results_list
,
bbox_results
,
lane_results
,
lclc_results
,
lcte_results
):
for
result_dict
,
bbox
,
lane
,
lclc
,
lcte
in
zip
(
results_list
,
bbox_results
,
lane_results
,
lclc_results
,
lcte_results
):
result_dict
[
'pred_te'
]
=
bbox
result_dict
[
'pred_te'
]
=
bbox
result_dict
[
'pred_lc'
]
=
lane
result_dict
[
'pred_lc'
]
=
lane
result_dict
[
'pred_topology_lclc'
]
=
lclc
result_dict
[
'pred_topology_lclc'
]
=
lclc
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/__init__.py
View file @
41b18fd8
from
.custom_detr_head
import
*
from
.topology_head
import
*
from
.lc_deformable_detr_head
import
LCDeformableDETRHead
from
.te_deformable_detr_head
import
TEDeformableDETRHead
from
.relationship_head
import
RelationshipHead
\ No newline at end of file
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/custom_detr_head.py
View file @
41b18fd8
# ==============================================================================
# ==============================================================================
# Binaries and/or source for the following packages or projects
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
#
...
@@ -23,17 +23,17 @@
...
@@ -23,17 +23,17 @@
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmcv.cnn
import
Linear
from
mmcv.cnn
import
Linear
from
mmdet.core
import
bbox_cxcywh_to_xyxy
,
bbox_xyxy_to_cxcywh
,
multi_apply
,
reduce_mean
from
mmdet.core
import
(
bbox_cxcywh_to_xyxy
,
bbox_xyxy_to_cxcywh
,
multi_apply
,
reduce_mean
)
from
mmdet.models
import
HEADS
,
DETRHead
from
mmdet.models
import
HEADS
,
DETRHead
@
HEADS
.
register_module
()
@
HEADS
.
register_module
()
class
CustomDETRHead
(
DETRHead
):
class
CustomDETRHead
(
DETRHead
):
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
,
in_channels
,
in_channels
,
num_query
,
num_query
,
object_type
,
object_type
,
...
@@ -46,7 +46,7 @@ class CustomDETRHead(DETRHead):
...
@@ -46,7 +46,7 @@ class CustomDETRHead(DETRHead):
ffn_dropout
=
0.1
,
ffn_dropout
=
0.1
,
**
kwargs
):
**
kwargs
):
self
.
object_type
=
object_type
self
.
object_type
=
object_type
if
self
.
object_type
==
'lane'
:
if
self
.
object_type
==
'lane'
:
self
.
num_reg_dim
=
num_reg_dim
self
.
num_reg_dim
=
num_reg_dim
assert
self
.
num_reg_dim
%
3
==
0
assert
self
.
num_reg_dim
%
3
==
0
...
@@ -57,7 +57,7 @@ class CustomDETRHead(DETRHead):
...
@@ -57,7 +57,7 @@ class CustomDETRHead(DETRHead):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
transformer
=
dict
(
transformer
=
dict
(
type
=
'Transformer'
,
type
=
'Transformer'
,
encoder
=
dict
(
encoder
=
dict
(
type
=
'DetrTransformerEncoder'
,
type
=
'DetrTransformerEncoder'
,
...
@@ -106,13 +106,13 @@ class CustomDETRHead(DETRHead):
...
@@ -106,13 +106,13 @@ class CustomDETRHead(DETRHead):
operation_order
=
(
'self_attn'
,
'norm'
,
'cross_attn'
,
'norm'
,
operation_order
=
(
'self_attn'
,
'norm'
,
'cross_attn'
,
'norm'
,
'ffn'
,
'norm'
)),
'ffn'
,
'norm'
)),
))
))
positional_encoding
=
dict
(
positional_encoding
=
dict
(
type
=
'SinePositionalEncoding'
,
num_feats
=
embed_dims
//
2
,
normalize
=
True
)
type
=
'SinePositionalEncoding'
,
num_feats
=
embed_dims
//
2
,
normalize
=
True
)
super
().
__init__
(
super
().
__init__
(
num_classes
=
num_classes
,
num_classes
=
num_classes
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
num_query
=
num_query
,
num_query
=
num_query
,
transformer
=
transformer
,
transformer
=
transformer
,
positional_encoding
=
positional_encoding
,
positional_encoding
=
positional_encoding
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -135,7 +135,7 @@ class CustomDETRHead(DETRHead):
...
@@ -135,7 +135,7 @@ class CustomDETRHead(DETRHead):
for
img_id
in
range
(
batch_size
):
for
img_id
in
range
(
batch_size
):
img_h
,
img_w
,
_
=
img_metas
[
img_id
][
'img_shape'
]
img_h
,
img_w
,
_
=
img_metas
[
img_id
][
'img_shape'
]
masks
[
img_id
,
:
img_h
,
:
img_w
]
=
0
masks
[
img_id
,
:
img_h
,
:
img_w
]
=
0
x
=
self
.
input_proj
(
x
)
x
=
self
.
input_proj
(
x
)
# interpolate masks to have the same spatial shape with x
# interpolate masks to have the same spatial shape with x
masks
=
F
.
interpolate
(
masks
=
F
.
interpolate
(
...
@@ -221,7 +221,7 @@ class CustomDETRHead(DETRHead):
...
@@ -221,7 +221,7 @@ class CustomDETRHead(DETRHead):
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
# construct weighted avg_factor to match with the official DETR repo
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor
=
num_total_pos
*
1.0
+
\
cls_avg_factor
=
num_total_pos
*
1.0
+
\
num_total_neg
*
self
.
bg_cls_weight
num_total_neg
*
self
.
bg_cls_weight
if
self
.
sync_cls_avg_factor
:
if
self
.
sync_cls_avg_factor
:
cls_avg_factor
=
reduce_mean
(
cls_avg_factor
=
reduce_mean
(
cls_scores
.
new_tensor
([
cls_avg_factor
]))
cls_scores
.
new_tensor
([
cls_avg_factor
]))
...
@@ -244,8 +244,8 @@ class CustomDETRHead(DETRHead):
...
@@ -244,8 +244,8 @@ class CustomDETRHead(DETRHead):
for
img_meta
,
bbox_pred
in
zip
(
img_metas
,
bbox_preds
):
for
img_meta
,
bbox_pred
in
zip
(
img_metas
,
bbox_preds
):
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
img_h
]).
unsqueeze
(
0
).
repeat
(
img_h
]).
unsqueeze
(
0
).
repeat
(
bbox_pred
.
size
(
0
),
1
)
bbox_pred
.
size
(
0
),
1
)
factors
.
append
(
factor
)
factors
.
append
(
factor
)
factors
=
torch
.
cat
(
factors
,
0
)
factors
=
torch
.
cat
(
factors
,
0
)
...
@@ -282,8 +282,8 @@ class CustomDETRHead(DETRHead):
...
@@ -282,8 +282,8 @@ class CustomDETRHead(DETRHead):
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
pos_inds_list
,
neg_inds_list
,
pos_assigned_gt_inds_list
)
=
multi_apply
(
bbox_weights_list
,
pos_inds_list
,
neg_inds_list
,
pos_assigned_gt_inds_list
)
=
multi_apply
(
self
.
_get_target_single
,
cls_scores_list
,
bbox_preds_list
,
self
.
_get_target_single
,
cls_scores_list
,
bbox_preds_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
gt_bboxes_ignore_list
)
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
gt_bboxes_ignore_list
)
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
assign_result
=
dict
(
assign_result
=
dict
(
...
@@ -312,7 +312,7 @@ class CustomDETRHead(DETRHead):
...
@@ -312,7 +312,7 @@ class CustomDETRHead(DETRHead):
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
# label targets
# label targets
labels
=
gt_bboxes
.
new_full
((
num_bboxes
,
),
labels
=
gt_bboxes
.
new_full
((
num_bboxes
,),
self
.
num_classes
,
self
.
num_classes
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
...
@@ -327,9 +327,12 @@ class CustomDETRHead(DETRHead):
...
@@ -327,9 +327,12 @@ class CustomDETRHead(DETRHead):
pos_gt_bboxes
=
sampling_result
.
pos_gt_bboxes
pos_gt_bboxes
=
sampling_result
.
pos_gt_bboxes
pos_gt_bboxes_normalized
=
torch
.
zeros_like
(
pos_gt_bboxes
)
pos_gt_bboxes_normalized
=
torch
.
zeros_like
(
pos_gt_bboxes
)
for
p
in
range
(
self
.
num_reg_dim
//
3
):
for
p
in
range
(
self
.
num_reg_dim
//
3
):
pos_gt_bboxes_normalized
[...,
3
*
p
]
=
(
pos_gt_bboxes
[...,
3
*
p
]
-
self
.
bev_range
[
0
])
/
(
self
.
bev_range
[
3
]
-
self
.
bev_range
[
0
])
pos_gt_bboxes_normalized
[...,
3
*
p
]
=
(
pos_gt_bboxes
[...,
3
*
p
]
-
self
.
bev_range
[
0
])
/
(
pos_gt_bboxes_normalized
[...,
3
*
p
+
1
]
=
(
pos_gt_bboxes
[...,
3
*
p
+
1
]
-
self
.
bev_range
[
1
])
/
(
self
.
bev_range
[
4
]
-
self
.
bev_range
[
1
])
self
.
bev_range
[
3
]
-
self
.
bev_range
[
0
])
pos_gt_bboxes_normalized
[...,
3
*
p
+
2
]
=
(
pos_gt_bboxes
[...,
3
*
p
+
2
]
-
self
.
bev_range
[
2
])
/
(
self
.
bev_range
[
5
]
-
self
.
bev_range
[
2
])
pos_gt_bboxes_normalized
[...,
3
*
p
+
1
]
=
(
pos_gt_bboxes
[...,
3
*
p
+
1
]
-
self
.
bev_range
[
1
])
/
(
self
.
bev_range
[
4
]
-
self
.
bev_range
[
1
])
pos_gt_bboxes_normalized
[...,
3
*
p
+
2
]
=
(
pos_gt_bboxes
[...,
3
*
p
+
2
]
-
self
.
bev_range
[
2
])
/
(
self
.
bev_range
[
5
]
-
self
.
bev_range
[
2
])
pos_gt_bboxes_targets
=
pos_gt_bboxes_normalized
pos_gt_bboxes_targets
=
pos_gt_bboxes_normalized
else
:
else
:
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
...
@@ -338,10 +341,10 @@ class CustomDETRHead(DETRHead):
...
@@ -338,10 +341,10 @@ class CustomDETRHead(DETRHead):
# Thus the learning target should be normalized by the image size, also
# Thus the learning target should be normalized by the image size, also
# the box format should be converted from defaultly x1y1x2y2 to cxcywh.
# the box format should be converted from defaultly x1y1x2y2 to cxcywh.
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
img_h
]).
unsqueeze
(
0
)
img_h
]).
unsqueeze
(
0
)
pos_gt_bboxes_normalized
=
sampling_result
.
pos_gt_bboxes
/
factor
pos_gt_bboxes_normalized
=
sampling_result
.
pos_gt_bboxes
/
factor
pos_gt_bboxes_targets
=
bbox_xyxy_to_cxcywh
(
pos_gt_bboxes_normalized
)
pos_gt_bboxes_targets
=
bbox_xyxy_to_cxcywh
(
pos_gt_bboxes_normalized
)
bbox_targets
[
pos_inds
]
=
pos_gt_bboxes_targets
bbox_targets
[
pos_inds
]
=
pos_gt_bboxes_targets
return
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pos_inds
,
return
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
pos_inds
,
neg_inds
,
pos_assigned_gt_inds
)
neg_inds
,
pos_assigned_gt_inds
)
...
@@ -360,12 +363,15 @@ class CustomDETRHead(DETRHead):
...
@@ -360,12 +363,15 @@ class CustomDETRHead(DETRHead):
det_bboxes
=
bbox_pred
det_bboxes
=
bbox_pred
for
p
in
range
(
self
.
num_reg_dim
//
3
):
for
p
in
range
(
self
.
num_reg_dim
//
3
):
det_bboxes
[...,
3
*
p
]
=
det_bboxes
[...,
3
*
p
]
*
(
self
.
bev_range
[
3
]
-
self
.
bev_range
[
0
])
+
self
.
bev_range
[
0
]
det_bboxes
[...,
3
*
p
]
=
det_bboxes
[...,
3
*
p
]
*
(
self
.
bev_range
[
3
]
-
self
.
bev_range
[
0
])
+
\
det_bboxes
[...,
3
*
p
+
1
]
=
det_bboxes
[...,
3
*
p
+
1
]
*
(
self
.
bev_range
[
4
]
-
self
.
bev_range
[
1
])
+
self
.
bev_range
[
1
]
self
.
bev_range
[
0
]
det_bboxes
[...,
3
*
p
+
2
]
=
det_bboxes
[...,
3
*
p
+
2
]
*
(
self
.
bev_range
[
5
]
-
self
.
bev_range
[
2
])
+
self
.
bev_range
[
2
]
det_bboxes
[...,
3
*
p
+
1
]
=
det_bboxes
[...,
3
*
p
+
1
]
*
(
self
.
bev_range
[
4
]
-
self
.
bev_range
[
1
])
+
\
det_bboxes
[...,
3
*
p
].
clamp_
(
min
=
self
.
bev_range
[
0
],
max
=
self
.
bev_range
[
3
])
self
.
bev_range
[
1
]
det_bboxes
[...,
3
*
p
+
1
].
clamp_
(
min
=
self
.
bev_range
[
1
],
max
=
self
.
bev_range
[
4
])
det_bboxes
[...,
3
*
p
+
2
]
=
det_bboxes
[...,
3
*
p
+
2
]
*
(
self
.
bev_range
[
5
]
-
self
.
bev_range
[
2
])
+
\
det_bboxes
[...,
3
*
p
+
2
].
clamp_
(
min
=
self
.
bev_range
[
2
],
max
=
self
.
bev_range
[
5
])
self
.
bev_range
[
2
]
det_bboxes
[...,
3
*
p
].
clamp_
(
min
=
self
.
bev_range
[
0
],
max
=
self
.
bev_range
[
3
])
det_bboxes
[...,
3
*
p
+
1
].
clamp_
(
min
=
self
.
bev_range
[
1
],
max
=
self
.
bev_range
[
4
])
det_bboxes
[...,
3
*
p
+
2
].
clamp_
(
min
=
self
.
bev_range
[
2
],
max
=
self
.
bev_range
[
5
])
else
:
else
:
# exclude background
# exclude background
if
self
.
loss_cls
.
use_sigmoid
:
if
self
.
loss_cls
.
use_sigmoid
:
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/lc_deformable_detr_head.py
View file @
41b18fd8
# ==============================================================================
# ==============================================================================
# Binaries and/or source for the following packages or projects
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
#
...
@@ -22,12 +22,8 @@
...
@@ -22,12 +22,8 @@
import
copy
import
copy
import
numpy
as
np
import
cv2
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
mmcv
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
build_activation_layer
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
build_activation_layer
from
mmcv.runner
import
auto_fp16
,
force_fp32
from
mmcv.runner
import
auto_fp16
,
force_fp32
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
...
@@ -79,15 +75,15 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -79,15 +75,15 @@ class LCDeformableDETRHead(AnchorFreeHead):
self
.
bg_cls_weight
=
0
self
.
bg_cls_weight
=
0
self
.
sync_cls_avg_factor
=
sync_cls_avg_factor
self
.
sync_cls_avg_factor
=
sync_cls_avg_factor
if
train_cfg
:
if
train_cfg
:
assert
'assigner'
in
train_cfg
,
'assigner should be provided '
\
assert
'assigner'
in
train_cfg
,
'assigner should be provided '
\
'when train_cfg is set.'
'when train_cfg is set.'
assigner
=
train_cfg
[
'assigner'
]
assigner
=
train_cfg
[
'assigner'
]
assert
loss_cls
[
'loss_weight'
]
==
assigner
[
'cls_cost'
][
'weight'
],
\
assert
loss_cls
[
'loss_weight'
]
==
assigner
[
'cls_cost'
][
'weight'
],
\
'The classification weight for loss and matcher should be'
\
'The classification weight for loss and matcher should be'
\
'exactly the same.'
'exactly the same.'
assert
loss_bbox
[
'loss_weight'
]
==
assigner
[
'reg_cost'
][
assert
loss_bbox
[
'loss_weight'
]
==
assigner
[
'reg_cost'
][
'weight'
],
'The regression L1 weight for loss and matcher '
\
'weight'
],
'The regression L1 weight for loss and matcher '
\
'should be exactly the same.'
'should be exactly the same.'
assert
loss_iou
[
'loss_weight'
]
==
assigner
[
'iou_cost'
][
'weight'
],
\
assert
loss_iou
[
'loss_weight'
]
==
assigner
[
'iou_cost'
][
'weight'
],
\
'The regression iou weight for loss and matcher should be'
\
'The regression iou weight for loss and matcher should be'
\
'exactly the same.'
'exactly the same.'
...
@@ -195,7 +191,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -195,7 +191,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
network, each is a 5D-tensor with shape
network, each is a 5D-tensor with shape
(B, N, C, H, W).
(B, N, C, H, W).
prev_bev: previous bev featues
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
only_bev: only compute BEV features with encoder.
Returns:
Returns:
all_cls_scores (Tensor): Outputs from the classification head,
\
all_cls_scores (Tensor): Outputs from the classification head,
\
shape [nb_dec, bs, num_query, cls_out_channels]. Note
\
shape [nb_dec, bs, num_query, cls_out_channels]. Note
\
...
@@ -232,12 +228,12 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -232,12 +228,12 @@ class LCDeformableDETRHead(AnchorFreeHead):
assert
reference
.
shape
[
-
1
]
==
3
assert
reference
.
shape
[
-
1
]
==
3
for
p
in
range
(
self
.
code_size
//
3
):
for
p
in
range
(
self
.
code_size
//
3
):
tmp
[...,
3
*
p
:
3
*
p
+
3
]
=
tmp
[...,
3
*
p
:
3
*
p
+
3
]
+
reference
tmp
[...,
3
*
p
:
3
*
p
+
3
]
=
tmp
[...,
3
*
p
:
3
*
p
+
3
]
+
reference
tmp
[...,
3
*
p
:
3
*
p
+
3
]
=
tmp
[...,
3
*
p
:
3
*
p
+
3
].
sigmoid
()
tmp
[...,
3
*
p
:
3
*
p
+
3
]
=
tmp
[...,
3
*
p
:
3
*
p
+
3
].
sigmoid
()
tmp
[...,
3
*
p
]
=
tmp
[...,
3
*
p
]
*
(
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
])
+
self
.
pc_range
[
0
]
tmp
[...,
3
*
p
]
=
tmp
[...,
3
*
p
]
*
(
self
.
pc_range
[
3
]
-
self
.
pc_range
[
0
])
+
self
.
pc_range
[
0
]
tmp
[...,
3
*
p
+
1
]
=
tmp
[...,
3
*
p
+
1
]
*
(
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
])
+
self
.
pc_range
[
1
]
tmp
[...,
3
*
p
+
1
]
=
tmp
[...,
3
*
p
+
1
]
*
(
self
.
pc_range
[
4
]
-
self
.
pc_range
[
1
])
+
self
.
pc_range
[
1
]
tmp
[...,
3
*
p
+
2
]
=
tmp
[...,
3
*
p
+
2
]
*
(
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
])
+
self
.
pc_range
[
2
]
tmp
[...,
3
*
p
+
2
]
=
tmp
[...,
3
*
p
+
2
]
*
(
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
])
+
self
.
pc_range
[
2
]
outputs_coord
=
tmp
outputs_coord
=
tmp
outputs_classes
.
append
(
outputs_class
)
outputs_classes
.
append
(
outputs_class
)
outputs_coords
.
append
(
outputs_coord
)
outputs_coords
.
append
(
outputs_coord
)
...
@@ -293,7 +289,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -293,7 +289,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
sampling_result
=
self
.
sampler
.
sample
(
assign_result
,
lanes_pred
,
sampling_result
=
self
.
sampler
.
sample
(
assign_result
,
lanes_pred
,
gt_lanes
)
gt_lanes
)
pos_inds
=
sampling_result
.
pos_inds
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
neg_inds
=
sampling_result
.
neg_inds
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
...
@@ -415,7 +411,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -415,7 +411,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
# construct weighted avg_factor to match with the official DETR repo
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor
=
num_total_pos
*
1.0
+
\
cls_avg_factor
=
num_total_pos
*
1.0
+
\
num_total_neg
*
self
.
bg_cls_weight
num_total_neg
*
self
.
bg_cls_weight
if
self
.
sync_cls_avg_factor
:
if
self
.
sync_cls_avg_factor
:
cls_avg_factor
=
reduce_mean
(
cls_avg_factor
=
reduce_mean
(
cls_scores
.
new_tensor
([
cls_avg_factor
]))
cls_scores
.
new_tensor
([
cls_avg_factor
]))
...
@@ -436,7 +432,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -436,7 +432,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
bbox_weights
=
bbox_weights
*
self
.
code_weights
bbox_weights
=
bbox_weights
*
self
.
code_weights
loss_bbox
=
self
.
loss_bbox
(
loss_bbox
=
self
.
loss_bbox
(
lanes_preds
[
isnotnan
,
:
self
.
code_size
],
lanes_preds
[
isnotnan
,
:
self
.
code_size
],
bbox_targets
[
isnotnan
,
:
self
.
code_size
],
bbox_targets
[
isnotnan
,
:
self
.
code_size
],
bbox_weights
[
isnotnan
,
:
self
.
code_size
],
bbox_weights
[
isnotnan
,
:
self
.
code_size
],
avg_factor
=
num_total_pos
)
avg_factor
=
num_total_pos
)
...
@@ -544,7 +540,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
...
@@ -544,7 +540,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
cls_scores
=
all_cls_scores
[
i
].
sigmoid
()
cls_scores
=
all_cls_scores
[
i
].
sigmoid
()
predictions_list
.
append
([
predictions_list
.
append
([
all_lanes_preds
[
i
].
detach
().
cpu
().
numpy
(),
all_lanes_preds
[
i
].
detach
().
cpu
().
numpy
(),
cls_scores
.
detach
().
cpu
().
numpy
()])
cls_scores
.
detach
().
cpu
().
numpy
()])
return
predictions_list
return
predictions_list
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/relationship_head.py
View file @
41b18fd8
import
copy
import
copy
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
cv2
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
mmcv
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
build_activation_layer
from
mmcv.cnn.bricks.transformer
import
build_feedforward_network
from
mmcv.cnn.bricks.transformer
import
build_feedforward_network
from
mmcv.runner
import
auto_fp16
,
force_fp32
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
from
mmdet.core
import
build_assigner
,
build_sampler
,
multi_apply
,
reduce_mean
from
mmdet.core
import
build_assigner
,
build_sampler
,
multi_apply
,
reduce_mean
from
mmdet.models.builder
import
HEADS
,
build_loss
from
mmdet.models.builder
import
HEADS
,
build_loss
from
mmdet.models.dense_heads
import
AnchorFreeHead
from
mmdet.models.dense_heads
import
AnchorFreeHead
from
mmdet.models.utils
import
build_transformer
from
mmdet.models.utils
import
build_transformer
from
mmdet.models.utils.transformer
import
inverse_sigmoid
class
MLP
(
nn
.
Module
):
class
MLP
(
nn
.
Module
):
...
@@ -38,10 +34,10 @@ class RelationshipHead(nn.Module):
...
@@ -38,10 +34,10 @@ class RelationshipHead(nn.Module):
in_channels_o2
=
None
,
in_channels_o2
=
None
,
shared_param
=
True
,
shared_param
=
True
,
loss_rel
=
dict
(
loss_rel
=
dict
(
type
=
'FocalLoss'
,
type
=
'FocalLoss'
,
use_sigmoid
=
True
,
use_sigmoid
=
True
,
gamma
=
2.0
,
gamma
=
2.0
,
alpha
=
0.25
)):
alpha
=
0.25
)):
super
().
__init__
()
super
().
__init__
()
self
.
MLP_o1
=
MLP
(
in_channels_o1
,
in_channels_o1
,
128
,
3
)
self
.
MLP_o1
=
MLP
(
in_channels_o1
,
in_channels_o1
,
128
,
3
)
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/te_deformable_detr_head.py
View file @
41b18fd8
# ==============================================================================
# ==============================================================================
# Binaries and/or source for the following packages or projects
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
#
...
@@ -27,12 +27,11 @@ import torch.nn as nn
...
@@ -27,12 +27,11 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
constant_init
from
mmcv.cnn
import
Linear
,
bias_init_with_prob
,
constant_init
from
mmcv.runner
import
force_fp32
from
mmcv.runner
import
force_fp32
from
mmdet.core
import
(
bbox_cxcywh_to_xyxy
,
bbox_xyxy_to_cxcywh
,
multi_apply
,
from
mmdet.core
import
(
bbox_cxcywh_to_xyxy
,
bbox_xyxy_to_cxcywh
,
reduce_mean
)
multi_apply
,
reduce_mean
)
from
mmdet.models
import
HEADS
from
mmdet.models.utils.transformer
import
inverse_sigmoid
from
mmdet.models
import
HEADS
,
build_loss
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet.models.dense_heads
import
DETRHead
from
mmdet.models.utils.transformer
import
inverse_sigmoid
@
HEADS
.
register_module
()
@
HEADS
.
register_module
()
...
@@ -163,14 +162,14 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -163,14 +162,14 @@ class TEDeformableDETRHead(DETRHead):
if
not
self
.
as_two_stage
:
if
not
self
.
as_two_stage
:
query_embeds
=
self
.
query_embedding
.
weight
query_embeds
=
self
.
query_embedding
.
weight
hs
,
init_reference
,
inter_references
,
\
hs
,
init_reference
,
inter_references
,
\
enc_outputs_class
,
enc_outputs_coord
=
self
.
transformer
(
enc_outputs_class
,
enc_outputs_coord
=
self
.
transformer
(
mlvl_feats
,
mlvl_feats
,
mlvl_masks
,
mlvl_masks
,
query_embeds
,
query_embeds
,
mlvl_positional_encodings
,
mlvl_positional_encodings
,
reg_branches
=
self
.
reg_branches
if
self
.
with_box_refine
else
None
,
# noqa:E501
reg_branches
=
self
.
reg_branches
if
self
.
with_box_refine
else
None
,
# noqa:E501
cls_branches
=
self
.
cls_branches
if
self
.
as_two_stage
else
None
# noqa:E501
cls_branches
=
self
.
cls_branches
if
self
.
as_two_stage
else
None
# noqa:E501
)
)
hs
=
hs
.
permute
(
0
,
2
,
1
,
3
)
hs
=
hs
.
permute
(
0
,
2
,
1
,
3
)
outputs_classes
=
[]
outputs_classes
=
[]
outputs_coords
=
[]
outputs_coords
=
[]
...
@@ -199,7 +198,7 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -199,7 +198,7 @@ class TEDeformableDETRHead(DETRHead):
'all_cls_scores'
:
outputs_classes
,
'all_cls_scores'
:
outputs_classes
,
'all_bbox_preds'
:
outputs_coords
,
'all_bbox_preds'
:
outputs_coords
,
'enc_cls_scores'
:
enc_outputs_class
if
self
.
as_two_stage
else
None
,
'enc_cls_scores'
:
enc_outputs_class
if
self
.
as_two_stage
else
None
,
'enc_bbox_preds'
:
enc_outputs_coord
.
sigmoid
()
if
self
.
as_two_stage
else
None
,
'enc_bbox_preds'
:
enc_outputs_coord
.
sigmoid
()
if
self
.
as_two_stage
else
None
,
'history_states'
:
hs
'history_states'
:
hs
}
}
...
@@ -336,7 +335,7 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -336,7 +335,7 @@ class TEDeformableDETRHead(DETRHead):
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
cls_scores
=
cls_scores
.
reshape
(
-
1
,
self
.
cls_out_channels
)
# construct weighted avg_factor to match with the official DETR repo
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor
=
num_total_pos
*
1.0
+
\
cls_avg_factor
=
num_total_pos
*
1.0
+
\
num_total_neg
*
self
.
bg_cls_weight
num_total_neg
*
self
.
bg_cls_weight
if
self
.
sync_cls_avg_factor
:
if
self
.
sync_cls_avg_factor
:
cls_avg_factor
=
reduce_mean
(
cls_avg_factor
=
reduce_mean
(
cls_scores
.
new_tensor
([
cls_avg_factor
]))
cls_scores
.
new_tensor
([
cls_avg_factor
]))
...
@@ -356,7 +355,7 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -356,7 +355,7 @@ class TEDeformableDETRHead(DETRHead):
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
img_h
,
img_w
,
_
=
img_meta
[
'img_shape'
]
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
factor
=
bbox_pred
.
new_tensor
([
img_w
,
img_h
,
img_w
,
img_h
]).
unsqueeze
(
0
).
repeat
(
img_h
]).
unsqueeze
(
0
).
repeat
(
bbox_pred
.
size
(
0
),
1
)
bbox_pred
.
size
(
0
),
1
)
factors
.
append
(
factor
)
factors
.
append
(
factor
)
factors
=
torch
.
cat
(
factors
,
0
)
factors
=
torch
.
cat
(
factors
,
0
)
...
@@ -426,8 +425,8 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -426,8 +425,8 @@ class TEDeformableDETRHead(DETRHead):
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
pos_inds_list
,
neg_inds_list
,
pos_assigned_gt_inds_list
)
=
multi_apply
(
bbox_weights_list
,
pos_inds_list
,
neg_inds_list
,
pos_assigned_gt_inds_list
)
=
multi_apply
(
self
.
_get_target_single
,
cls_scores_list
,
bbox_preds_list
,
self
.
_get_target_single
,
cls_scores_list
,
bbox_preds_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
gt_bboxes_ignore_list
)
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
gt_bboxes_ignore_list
)
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_pos
=
sum
((
inds
.
numel
()
for
inds
in
pos_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
num_total_neg
=
sum
((
inds
.
numel
()
for
inds
in
neg_inds_list
))
assign_result
=
dict
(
assign_result
=
dict
(
...
@@ -484,7 +483,7 @@ class TEDeformableDETRHead(DETRHead):
...
@@ -484,7 +483,7 @@ class TEDeformableDETRHead(DETRHead):
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
pos_assigned_gt_inds
=
sampling_result
.
pos_assigned_gt_inds
# label targets
# label targets
labels
=
gt_bboxes
.
new_full
((
num_bboxes
,
),
labels
=
gt_bboxes
.
new_full
((
num_bboxes
,),
self
.
num_classes
,
self
.
num_classes
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/heads/topology_head.py
View file @
41b18fd8
# ==============================================================================
# ==============================================================================
# Binaries and/or source for the following packages or projects
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# are presented under one or more of the following open source licenses:
# topology_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
# topology_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
#
...
@@ -23,7 +23,6 @@
...
@@ -23,7 +23,6 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
mmcv.runner
import
BaseModule
from
mmcv.runner
import
BaseModule
from
mmdet.models
import
HEADS
,
build_loss
from
mmdet.models
import
HEADS
,
build_loss
...
@@ -41,13 +40,14 @@ class MLP(nn.Module):
...
@@ -41,13 +40,14 @@ class MLP(nn.Module):
x
=
F
.
relu
(
layer
(
x
))
if
i
<
self
.
num_layers
-
1
else
layer
(
x
)
x
=
F
.
relu
(
layer
(
x
))
if
i
<
self
.
num_layers
-
1
else
layer
(
x
)
return
x
return
x
@
HEADS
.
register_module
()
@
HEADS
.
register_module
()
class
TopologyHead
(
BaseModule
):
class
TopologyHead
(
BaseModule
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
hidden_channels
,
hidden_channels
,
out_channels
,
out_channels
,
num_layers
,
num_layers
,
loss_cls
):
loss_cls
):
...
@@ -94,10 +94,11 @@ class TopologyHead(BaseModule):
...
@@ -94,10 +94,11 @@ class TopologyHead(BaseModule):
target
=
pred_adj
.
new_zeros
(
pred_adj
[
b
].
shape
[:
-
1
])
target
=
pred_adj
.
new_zeros
(
pred_adj
[
b
].
shape
[:
-
1
])
rs
=
row_assign_result
[
'pos_inds'
][
b
].
unsqueeze
(
-
1
).
repeat
(
1
,
column_assign_result
[
'pos_inds'
][
b
].
shape
[
0
])
rs
=
row_assign_result
[
'pos_inds'
][
b
].
unsqueeze
(
-
1
).
repeat
(
1
,
column_assign_result
[
'pos_inds'
][
b
].
shape
[
0
])
cs
=
column_assign_result
[
'pos_inds'
][
b
].
unsqueeze
(
0
).
repeat
(
row_assign_result
[
'pos_inds'
][
b
].
shape
[
0
],
1
)
cs
=
column_assign_result
[
'pos_inds'
][
b
].
unsqueeze
(
0
).
repeat
(
row_assign_result
[
'pos_inds'
][
b
].
shape
[
0
],
1
)
target
[
rs
,
cs
]
=
gt_adj
[
b
][
row_assign_result
[
'pos_assigned_gt_inds'
][
b
]][:,
column_assign_result
[
'pos_assigned_gt_inds'
][
b
]].
float
()
target
[
rs
,
cs
]
=
gt_adj
[
b
][
row_assign_result
[
'pos_assigned_gt_inds'
][
b
]][:,
column_assign_result
[
'pos_assigned_gt_inds'
][
b
]].
float
()
targets
.
append
(
target
)
targets
.
append
(
target
)
targets
=
1
-
torch
.
stack
(
targets
,
dim
=
0
)
# 0 as positive
targets
=
1
-
torch
.
stack
(
targets
,
dim
=
0
)
# 0 as positive
loss_dict
=
dict
()
loss_dict
=
dict
()
pred_adj
=
pred_adj
.
reshape
(
-
1
,
self
.
out_channels
)
pred_adj
=
pred_adj
.
reshape
(
-
1
,
self
.
out_channels
)
...
...
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/__init__.py
View file @
41b18fd8
from
.spatial_cross_attention
import
SpatialCrossAttention
,
MSDeformableAttention3D
from
.temporal_self_attention
import
TemporalSelfAttention
from
.encoder
import
BEVFormerEncoder
,
BEVFormerLayer
from
.decoder
import
LaneDetectionTransformerDecoder
from
.bevformer_constructer
import
BEVFormerConstructer
from
.transformer
import
PerceptionTransformer
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/bevformer_constructer.py
View file @
41b18fd8
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn.init
import
normal_
from
torchvision.transforms.functional
import
rotate
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.transformer
import
build_transformer_layer_sequence
,
build_positional_encoding
from
mmcv.cnn.bricks.transformer
import
(
build_positional_encoding
,
build_transformer_layer_sequence
)
from
mmcv.runner.base_module
import
BaseModule
from
mmcv.runner.base_module
import
BaseModule
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmdet.models.utils.builder
import
TRANSFORMER
from
mmdet3d.models
import
NECKS
from
mmdet3d.models
import
NECKS
from
torch.nn.init
import
normal_
from
torchvision.transforms.functional
import
rotate
from
.temporal_self_attention
import
TemporalSelfAttention
from
.spatial_cross_attention
import
MSDeformableAttention3D
from
.decoder
import
CustomMSDeformableAttention
from
.decoder
import
CustomMSDeformableAttention
from
.spatial_cross_attention
import
MSDeformableAttention3D
from
.temporal_self_attention
import
TemporalSelfAttention
@
NECKS
.
register_module
()
@
NECKS
.
register_module
()
...
@@ -69,7 +67,7 @@ class BEVFormerConstructer(BaseModule):
...
@@ -69,7 +67,7 @@ class BEVFormerConstructer(BaseModule):
def
init_layers
(
self
):
def
init_layers
(
self
):
self
.
bev_embedding
=
nn
.
Embedding
(
self
.
bev_embedding
=
nn
.
Embedding
(
self
.
bev_h
*
self
.
bev_w
,
self
.
embed_dims
)
self
.
bev_h
*
self
.
bev_w
,
self
.
embed_dims
)
self
.
level_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
level_embeds
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
num_feature_levels
,
self
.
embed_dims
))
self
.
num_feature_levels
,
self
.
embed_dims
))
self
.
cams_embeds
=
nn
.
Parameter
(
self
.
cams_embeds
=
nn
.
Parameter
(
...
@@ -82,7 +80,7 @@ class BEVFormerConstructer(BaseModule):
...
@@ -82,7 +80,7 @@ class BEVFormerConstructer(BaseModule):
)
)
if
self
.
can_bus_norm
:
if
self
.
can_bus_norm
:
self
.
can_bus_mlp
.
add_module
(
'norm'
,
nn
.
LayerNorm
(
self
.
embed_dims
))
self
.
can_bus_mlp
.
add_module
(
'norm'
,
nn
.
LayerNorm
(
self
.
embed_dims
))
def
init_weights
(
self
):
def
init_weights
(
self
):
"""Initialize the transformer weights."""
"""Initialize the transformer weights."""
for
p
in
self
.
parameters
():
for
p
in
self
.
parameters
():
...
@@ -117,9 +115,9 @@ class BEVFormerConstructer(BaseModule):
...
@@ -117,9 +115,9 @@ class BEVFormerConstructer(BaseModule):
# obtain rotation angle and shift with ego motion
# obtain rotation angle and shift with ego motion
delta_x
=
np
.
array
([
each
[
'can_bus'
][
0
]
delta_x
=
np
.
array
([
each
[
'can_bus'
][
0
]
for
each
in
img_metas
])
for
each
in
img_metas
])
delta_y
=
np
.
array
([
each
[
'can_bus'
][
1
]
delta_y
=
np
.
array
([
each
[
'can_bus'
][
1
]
for
each
in
img_metas
])
for
each
in
img_metas
])
ego_angle
=
np
.
array
(
ego_angle
=
np
.
array
(
[
each
[
'can_bus'
][
-
2
]
/
np
.
pi
*
180
for
each
in
img_metas
])
[
each
[
'can_bus'
][
-
2
]
/
np
.
pi
*
180
for
each
in
img_metas
])
...
@@ -129,9 +127,9 @@ class BEVFormerConstructer(BaseModule):
...
@@ -129,9 +127,9 @@ class BEVFormerConstructer(BaseModule):
translation_angle
=
np
.
arctan2
(
delta_y
,
delta_x
)
/
np
.
pi
*
180
translation_angle
=
np
.
arctan2
(
delta_y
,
delta_x
)
/
np
.
pi
*
180
bev_angle
=
ego_angle
-
translation_angle
bev_angle
=
ego_angle
-
translation_angle
shift_y
=
translation_length
*
\
shift_y
=
translation_length
*
\
np
.
cos
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_y
/
self
.
bev_h
np
.
cos
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_y
/
self
.
bev_h
shift_x
=
translation_length
*
\
shift_x
=
translation_length
*
\
np
.
sin
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_x
/
self
.
bev_w
np
.
sin
(
bev_angle
/
180
*
np
.
pi
)
/
grid_length_x
/
self
.
bev_w
shift_y
=
shift_y
*
self
.
use_shift
shift_y
=
shift_y
*
self
.
use_shift
shift_x
=
shift_x
*
self
.
use_shift
shift_x
=
shift_x
*
self
.
use_shift
shift
=
bev_queries
.
new_tensor
(
shift
=
bev_queries
.
new_tensor
(
...
@@ -167,7 +165,7 @@ class BEVFormerConstructer(BaseModule):
...
@@ -167,7 +165,7 @@ class BEVFormerConstructer(BaseModule):
if
self
.
use_cams_embeds
:
if
self
.
use_cams_embeds
:
feat
=
feat
+
self
.
cams_embeds
[:,
None
,
None
,
:].
to
(
feat
.
dtype
)
feat
=
feat
+
self
.
cams_embeds
[:,
None
,
None
,
:].
to
(
feat
.
dtype
)
feat
=
feat
+
self
.
level_embeds
[
None
,
feat
=
feat
+
self
.
level_embeds
[
None
,
None
,
lvl
:
lvl
+
1
,
:].
to
(
feat
.
dtype
)
None
,
lvl
:
lvl
+
1
,
:].
to
(
feat
.
dtype
)
spatial_shapes
.
append
(
spatial_shape
)
spatial_shapes
.
append
(
spatial_shape
)
feat_flatten
.
append
(
feat
)
feat_flatten
.
append
(
feat
)
...
@@ -196,4 +194,3 @@ class BEVFormerConstructer(BaseModule):
...
@@ -196,4 +194,3 @@ class BEVFormerConstructer(BaseModule):
)
)
return
bev_embed
return
bev_embed
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/custom_base_transformer_layer.py
View file @
41b18fd8
# ---------------------------------------------
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# ---------------------------------------------
# Modified by Zhiqi Li
# Modified by Zhiqi Li
# ---------------------------------------------
# ---------------------------------------------
import
copy
import
copy
import
warnings
import
warnings
import
torch
import
torch
import
torch.nn
as
nn
from
mmcv
import
ConfigDict
from
mmcv.cnn
import
build_norm_layer
from
mmcv
import
ConfigDict
,
deprecated_api_warning
from
mmcv.cnn.bricks.registry
import
TRANSFORMER_LAYER
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
build_norm_layer
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
FEEDFORWARD_NETWORK
,
POSITIONAL_ENCODING
,
try
:
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.ops.multi_scale_deform_attn
import
\
MultiScaleDeformableAttention
# noqa F401
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try
:
warnings
.
warn
(
from
mmcv.ops.multi_scale_deform_attn
import
MultiScaleDeformableAttention
# noqa F401
ImportWarning
(
warnings
.
warn
(
'``MultiScaleDeformableAttention`` has been moved to '
ImportWarning
(
'``mmcv.ops.multi_scale_deform_attn``, please change original path '
# noqa E501
'``MultiScaleDeformableAttention`` has been moved to '
'``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '
# noqa E501
'``mmcv.ops.multi_scale_deform_attn``, please change original path '
# noqa E501
'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '
# noqa E501
'``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '
# noqa E501
))
'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '
# noqa E501
except
ImportError
:
))
warnings
.
warn
(
'Fail to import ``MultiScaleDeformableAttention`` from '
except
ImportError
:
'``mmcv.ops.multi_scale_deform_attn``, '
warnings
.
warn
(
'Fail to import ``MultiScaleDeformableAttention`` from '
'You should install ``mmcv-full`` if you need this module. '
)
'``mmcv.ops.multi_scale_deform_attn``, '
from
mmcv.cnn.bricks.transformer
import
(
build_attention
,
'You should install ``mmcv-full`` if you need this module. '
)
build_feedforward_network
)
from
mmcv.cnn.bricks.transformer
import
build_feedforward_network
,
build_attention
@
TRANSFORMER_LAYER
.
register_module
()
@
TRANSFORMER_LAYER
.
register_module
()
class
MyCustomBaseTransformerLayer
(
BaseModule
):
class
MyCustomBaseTransformerLayer
(
BaseModule
):
"""Base `TransformerLayer` for vision transformer.
"""Base `TransformerLayer` for vision transformer.
It can be built from `mmcv.ConfigDict` and support more flexible
It can be built from `mmcv.ConfigDict` and support more flexible
customization, for example, using any number of `FFN or LN ` and
customization, for example, using any number of `FFN or LN ` and
use different kinds of `attention` by specifying a list of `ConfigDict`
use different kinds of `attention` by specifying a list of `ConfigDict`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
when you specifying `norm` as the first element of `operation_order`.
when you specifying `norm` as the first element of `operation_order`.
More details about the `prenorm`: `On Layer Normalization in the
More details about the `prenorm`: `On Layer Normalization in the
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Args:
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for `self_attention` or `cross_attention` modules,
Configs for `self_attention` or `cross_attention` modules,
The order of the configs in the list should be consistent with
The order of the configs in the list should be consistent with
corresponding attentions in operation_order.
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
will be built with this config. Default: None.
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for FFN, The order of the configs in the list should be
Configs for FFN, The order of the configs in the list should be
consistent with corresponding ffn in operation_order.
consistent with corresponding ffn in operation_order.
If it is a dict, all of the attention modules in operation_order
If it is a dict, all of the attention modules in operation_order
will be built with this config.
will be built with this config.
operation_order (tuple[str]): The execution order of operation
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Support `prenorm` when you specifying first element as `norm`.
Support `prenorm` when you specifying first element as `norm`.
Default:None.
Default:None.
norm_cfg (dict): Config dict for normalization layer.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
Default: dict(type='LN').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
Default: None.
batch_first (bool): Key, Query and Value are shape
batch_first (bool): Key, Query and Value are shape
of (batch, n, embed_dim)
of (batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
or (n, batch, embed_dim). Default to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
attn_cfgs
=
None
,
attn_cfgs
=
None
,
ffn_cfgs
=
dict
(
ffn_cfgs
=
dict
(
type
=
'FFN'
,
type
=
'FFN'
,
embed_dims
=
256
,
embed_dims
=
256
,
feedforward_channels
=
1024
,
feedforward_channels
=
1024
,
num_fcs
=
2
,
num_fcs
=
2
,
ffn_drop
=
0.
,
ffn_drop
=
0.
,
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
),
),
operation_order
=
None
,
operation_order
=
None
,
norm_cfg
=
dict
(
type
=
'LN'
),
norm_cfg
=
dict
(
type
=
'LN'
),
init_cfg
=
None
,
init_cfg
=
None
,
batch_first
=
True
,
batch_first
=
True
,
**
kwargs
):
**
kwargs
):
deprecated_args
=
dict
(
deprecated_args
=
dict
(
feedforward_channels
=
'feedforward_channels'
,
feedforward_channels
=
'feedforward_channels'
,
ffn_dropout
=
'ffn_drop'
,
ffn_dropout
=
'ffn_drop'
,
ffn_num_fcs
=
'num_fcs'
)
ffn_num_fcs
=
'num_fcs'
)
for
ori_name
,
new_name
in
deprecated_args
.
items
():
for
ori_name
,
new_name
in
deprecated_args
.
items
():
if
ori_name
in
kwargs
:
if
ori_name
in
kwargs
:
warnings
.
warn
(
warnings
.
warn
(
f
'The arguments `
{
ori_name
}
` in BaseTransformerLayer '
f
'The arguments `
{
ori_name
}
` in BaseTransformerLayer '
f
'has been deprecated, now you should set `
{
new_name
}
` '
f
'has been deprecated, now you should set `
{
new_name
}
` '
f
'and other FFN related arguments '
f
'and other FFN related arguments '
f
'to a dict named `ffn_cfgs`. '
)
f
'to a dict named `ffn_cfgs`. '
)
ffn_cfgs
[
new_name
]
=
kwargs
[
ori_name
]
ffn_cfgs
[
new_name
]
=
kwargs
[
ori_name
]
super
(
MyCustomBaseTransformerLayer
,
self
).
__init__
(
init_cfg
)
super
(
MyCustomBaseTransformerLayer
,
self
).
__init__
(
init_cfg
)
self
.
batch_first
=
batch_first
self
.
batch_first
=
batch_first
assert
set
(
operation_order
)
&
set
(
assert
set
(
operation_order
)
&
set
(
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
])
==
\
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
])
==
\
set
(
operation_order
),
f
'The operation_order of'
\
set
(
operation_order
),
f
'The operation_order of'
\
f
'
{
self
.
__class__
.
__name__
}
should '
\
f
'
{
self
.
__class__
.
__name__
}
should '
\
f
'contains all four operation type '
\
f
'contains all four operation type '
\
f
"
{
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
]
}
"
f
"
{
[
'self_attn'
,
'norm'
,
'ffn'
,
'cross_attn'
]
}
"
num_attn
=
operation_order
.
count
(
'self_attn'
)
+
operation_order
.
count
(
num_attn
=
operation_order
.
count
(
'self_attn'
)
+
operation_order
.
count
(
'cross_attn'
)
'cross_attn'
)
if
isinstance
(
attn_cfgs
,
dict
):
if
isinstance
(
attn_cfgs
,
dict
):
attn_cfgs
=
[
copy
.
deepcopy
(
attn_cfgs
)
for
_
in
range
(
num_attn
)]
attn_cfgs
=
[
copy
.
deepcopy
(
attn_cfgs
)
for
_
in
range
(
num_attn
)]
else
:
else
:
assert
num_attn
==
len
(
attn_cfgs
),
f
'The length '
\
assert
num_attn
==
len
(
attn_cfgs
),
f
'The length '
\
f
'of attn_cfg
{
num_attn
}
is '
\
f
'of attn_cfg
{
num_attn
}
is '
\
f
'not consistent with the number of attention'
\
f
'not consistent with the number of attention'
\
f
'in operation_order
{
operation_order
}
.'
f
'in operation_order
{
operation_order
}
.'
self
.
num_attn
=
num_attn
self
.
num_attn
=
num_attn
self
.
operation_order
=
operation_order
self
.
operation_order
=
operation_order
self
.
norm_cfg
=
norm_cfg
self
.
norm_cfg
=
norm_cfg
self
.
pre_norm
=
operation_order
[
0
]
==
'norm'
self
.
pre_norm
=
operation_order
[
0
]
==
'norm'
self
.
attentions
=
ModuleList
()
self
.
attentions
=
ModuleList
()
index
=
0
index
=
0
for
operation_name
in
operation_order
:
for
operation_name
in
operation_order
:
if
operation_name
in
[
'self_attn'
,
'cross_attn'
]:
if
operation_name
in
[
'self_attn'
,
'cross_attn'
]:
if
'batch_first'
in
attn_cfgs
[
index
]:
if
'batch_first'
in
attn_cfgs
[
index
]:
assert
self
.
batch_first
==
attn_cfgs
[
index
][
'batch_first'
]
assert
self
.
batch_first
==
attn_cfgs
[
index
][
'batch_first'
]
else
:
else
:
attn_cfgs
[
index
][
'batch_first'
]
=
self
.
batch_first
attn_cfgs
[
index
][
'batch_first'
]
=
self
.
batch_first
attention
=
build_attention
(
attn_cfgs
[
index
])
attention
=
build_attention
(
attn_cfgs
[
index
])
# Some custom attentions used as `self_attn`
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
# or `cross_attn` can have different behavior.
attention
.
operation_name
=
operation_name
attention
.
operation_name
=
operation_name
self
.
attentions
.
append
(
attention
)
self
.
attentions
.
append
(
attention
)
index
+=
1
index
+=
1
self
.
embed_dims
=
self
.
attentions
[
0
].
embed_dims
self
.
embed_dims
=
self
.
attentions
[
0
].
embed_dims
self
.
ffns
=
ModuleList
()
self
.
ffns
=
ModuleList
()
num_ffns
=
operation_order
.
count
(
'ffn'
)
num_ffns
=
operation_order
.
count
(
'ffn'
)
if
isinstance
(
ffn_cfgs
,
dict
):
if
isinstance
(
ffn_cfgs
,
dict
):
ffn_cfgs
=
ConfigDict
(
ffn_cfgs
)
ffn_cfgs
=
ConfigDict
(
ffn_cfgs
)
if
isinstance
(
ffn_cfgs
,
dict
):
if
isinstance
(
ffn_cfgs
,
dict
):
ffn_cfgs
=
[
copy
.
deepcopy
(
ffn_cfgs
)
for
_
in
range
(
num_ffns
)]
ffn_cfgs
=
[
copy
.
deepcopy
(
ffn_cfgs
)
for
_
in
range
(
num_ffns
)]
assert
len
(
ffn_cfgs
)
==
num_ffns
assert
len
(
ffn_cfgs
)
==
num_ffns
for
ffn_index
in
range
(
num_ffns
):
for
ffn_index
in
range
(
num_ffns
):
if
'embed_dims'
not
in
ffn_cfgs
[
ffn_index
]:
if
'embed_dims'
not
in
ffn_cfgs
[
ffn_index
]:
ffn_cfgs
[
'embed_dims'
]
=
self
.
embed_dims
ffn_cfgs
[
'embed_dims'
]
=
self
.
embed_dims
else
:
else
:
assert
ffn_cfgs
[
ffn_index
][
'embed_dims'
]
==
self
.
embed_dims
assert
ffn_cfgs
[
ffn_index
][
'embed_dims'
]
==
self
.
embed_dims
self
.
ffns
.
append
(
self
.
ffns
.
append
(
build_feedforward_network
(
ffn_cfgs
[
ffn_index
]))
build_feedforward_network
(
ffn_cfgs
[
ffn_index
]))
self
.
norms
=
ModuleList
()
self
.
norms
=
ModuleList
()
num_norms
=
operation_order
.
count
(
'norm'
)
num_norms
=
operation_order
.
count
(
'norm'
)
for
_
in
range
(
num_norms
):
for
_
in
range
(
num_norms
):
self
.
norms
.
append
(
build_norm_layer
(
norm_cfg
,
self
.
embed_dims
)[
1
])
self
.
norms
.
append
(
build_norm_layer
(
norm_cfg
,
self
.
embed_dims
)[
1
])
def
forward
(
self
,
def
forward
(
self
,
query
,
query
,
key
=
None
,
key
=
None
,
value
=
None
,
value
=
None
,
query_pos
=
None
,
query_pos
=
None
,
key_pos
=
None
,
key_pos
=
None
,
attn_masks
=
None
,
attn_masks
=
None
,
query_key_padding_mask
=
None
,
query_key_padding_mask
=
None
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
**
kwargs
):
**
kwargs
):
"""Forward function for `TransformerDecoderLayer`.
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
**kwargs contains some specific arguments of attentions.
Args:
Args:
query (Tensor): The input query with shape
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
[num_queries, bs, embed_dims] if
self.batch_first is False, else
self.batch_first is False, else
[bs, num_queries embed_dims].
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
it should equal to the number of `attention` in
`operation_order`. Default: None.
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
shape [bs, num_keys]. Default: None.
Returns:
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
"""
norm_index
=
0
norm_index
=
0
attn_index
=
0
attn_index
=
0
ffn_index
=
0
ffn_index
=
0
identity
=
query
identity
=
query
if
attn_masks
is
None
:
if
attn_masks
is
None
:
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
attn_masks
=
[
attn_masks
=
[
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
]
]
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
f
'
{
self
.
__class__
.
__name__
}
'
)
f
'
{
self
.
__class__
.
__name__
}
'
)
else
:
else
:
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'to the number of attention in '
\
f
'to the number of attention in '
\
f
'operation_order
{
self
.
num_attn
}
'
f
'operation_order
{
self
.
num_attn
}
'
for
layer
in
self
.
operation_order
:
for
layer
in
self
.
operation_order
:
if
layer
==
'self_attn'
:
if
layer
==
'self_attn'
:
temp_key
=
temp_value
=
query
temp_key
=
temp_value
=
query
query
=
self
.
attentions
[
attn_index
](
query
=
self
.
attentions
[
attn_index
](
query
,
query
,
temp_key
,
temp_key
,
temp_value
,
temp_value
,
identity
if
self
.
pre_norm
else
None
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
query_pos
=
query_pos
,
key_pos
=
query_pos
,
key_pos
=
query_pos
,
attn_mask
=
attn_masks
[
attn_index
],
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
query_key_padding_mask
,
key_padding_mask
=
query_key_padding_mask
,
**
kwargs
)
**
kwargs
)
attn_index
+=
1
attn_index
+=
1
identity
=
query
identity
=
query
elif
layer
==
'norm'
:
elif
layer
==
'norm'
:
query
=
self
.
norms
[
norm_index
](
query
)
query
=
self
.
norms
[
norm_index
](
query
)
norm_index
+=
1
norm_index
+=
1
elif
layer
==
'cross_attn'
:
elif
layer
==
'cross_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
=
self
.
attentions
[
attn_index
](
query
,
query
,
key
,
key
,
value
,
value
,
identity
if
self
.
pre_norm
else
None
,
identity
if
self
.
pre_norm
else
None
,
query_pos
=
query_pos
,
query_pos
=
query_pos
,
key_pos
=
key_pos
,
key_pos
=
key_pos
,
attn_mask
=
attn_masks
[
attn_index
],
attn_mask
=
attn_masks
[
attn_index
],
key_padding_mask
=
key_padding_mask
,
key_padding_mask
=
key_padding_mask
,
**
kwargs
)
**
kwargs
)
attn_index
+=
1
attn_index
+=
1
identity
=
query
identity
=
query
elif
layer
==
'ffn'
:
elif
layer
==
'ffn'
:
query
=
self
.
ffns
[
ffn_index
](
query
=
self
.
ffns
[
ffn_index
](
query
,
identity
if
self
.
pre_norm
else
None
)
query
,
identity
if
self
.
pre_norm
else
None
)
ffn_index
+=
1
ffn_index
+=
1
return
query
return
query
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/decoder.py
View file @
41b18fd8
# ---------------------------------------------
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# ---------------------------------------------
# Modified by Zhiqi Li
# Modified by Zhiqi Li
# ---------------------------------------------
# ---------------------------------------------
from
cmath
import
pi
import
math
from
mmcv.ops.multi_scale_deform_attn
import
multi_scale_deformable_attn_pytorch
import
warnings
import
mmcv
import
cv2
as
cv
import
torch
import
copy
import
torch.nn
as
nn
import
warnings
from
mmcv.cnn
import
constant_init
,
xavier_init
from
matplotlib
import
pyplot
as
plt
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
import
numpy
as
np
TRANSFORMER_LAYER_SEQUENCE
)
import
torch
from
mmcv.cnn.bricks.transformer
import
(
BaseTransformerLayer
,
import
torch.nn
as
nn
TransformerLayerSequence
)
import
torch.nn.functional
as
F
from
mmcv.ops.multi_scale_deform_attn
import
\
from
mmcv.cnn
import
xavier_init
,
constant_init
multi_scale_deformable_attn_pytorch
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
TRANSFORMER_LAYER
,
from
mmcv.runner.base_module
import
BaseModule
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.utils
import
deprecated_api_warning
,
ext_loader
from
mmcv.cnn.bricks.transformer
import
BaseTransformerLayer
,
TransformerLayerSequence
import
math
from
.multi_scale_deformable_attn_function
import
\
from
mmcv.runner.base_module
import
BaseModule
,
ModuleList
,
Sequential
MultiScaleDeformableAttnFunction_fp32
from
mmcv.utils
import
(
ConfigDict
,
build_from_cfg
,
deprecated_api_warning
,
to_2tuple
)
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
from
mmcv.utils
import
ext_loader
from
.multi_scale_deformable_attn_function
import
MultiScaleDeformableAttnFunction_fp32
,
\
MultiScaleDeformableAttnFunction_fp16
def
inverse_sigmoid
(
x
,
eps
=
1e-5
):
"""Inverse function of sigmoid.
ext_module
=
ext_loader
.
load_ext
(
Args:
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
def
inverse_sigmoid
(
x
,
eps
=
1e-5
):
overflow. Defaults 1e-5.
"""Inverse function of sigmoid.
Returns:
Args:
Tensor: The x has passed the inverse
x (Tensor): The tensor to do the
function of sigmoid, has same
inverse.
shape with input.
eps (float): EPS avoid numerical
"""
overflow. Defaults 1e-5.
x
=
x
.
clamp
(
min
=
0
,
max
=
1
)
Returns:
x1
=
x
.
clamp
(
min
=
eps
)
Tensor: The x has passed the inverse
x2
=
(
1
-
x
).
clamp
(
min
=
eps
)
function of sigmoid, has same
return
torch
.
log
(
x1
/
x2
)
shape with input.
"""
x
=
x
.
clamp
(
min
=
0
,
max
=
1
)
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
x1
=
x
.
clamp
(
min
=
eps
)
class
LaneDetectionTransformerDecoder
(
TransformerLayerSequence
):
x2
=
(
1
-
x
).
clamp
(
min
=
eps
)
return
torch
.
log
(
x1
/
x2
)
def
__init__
(
self
,
*
args
,
return_intermediate
=
False
,
**
kwargs
):
super
(
LaneDetectionTransformerDecoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
return_intermediate
=
return_intermediate
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
self
.
fp16_enabled
=
False
class
LaneDetectionTransformerDecoder
(
TransformerLayerSequence
):
def
forward
(
self
,
def
__init__
(
self
,
*
args
,
return_intermediate
=
False
,
**
kwargs
):
query
,
super
(
LaneDetectionTransformerDecoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
*
args
,
self
.
return_intermediate
=
return_intermediate
reference_points
=
None
,
self
.
fp16_enabled
=
False
reg_branches
=
None
,
key_padding_mask
=
None
,
def
forward
(
self
,
**
kwargs
):
query
,
"""Forward function for `Detr3DTransformerDecoder`.
*
args
,
Args:
reference_points
=
None
,
query (Tensor): Input query with shape
reg_branches
=
None
,
`(num_query, bs, embed_dims)`.
key_padding_mask
=
None
,
reference_points (Tensor): The reference
**
kwargs
):
points of offset. has shape
"""Forward function for `Detr3DTransformerDecoder`.
(bs, num_query, 4) when as_two_stage,
Args:
otherwise has shape ((bs, num_query, 2).
query (Tensor): Input query with shape
reg_branch: (obj:`nn.ModuleList`): Used for
`(num_query, bs, embed_dims)`.
refining the regression results. Only would
reference_points (Tensor): The reference
be passed when with_box_refine is True,
points of offset. has shape
otherwise would be passed a `None`.
(bs, num_query, 4) when as_two_stage,
Returns:
otherwise has shape ((bs, num_query, 2).
Tensor: Results with shape [1, num_query, bs, embed_dims] when
reg_branch: (obj:`nn.ModuleList`): Used for
return_intermediate is `False`, otherwise it has shape
refining the regression results. Only would
[num_layers, num_query, bs, embed_dims].
be passed when with_box_refine is True,
"""
otherwise would be passed a `None`.
Returns:
output
=
query
Tensor: Results with shape [1, num_query, bs, embed_dims] when
intermediate
=
[]
return_intermediate is `False`, otherwise it has shape
intermediate_reference_points
=
[]
[num_layers, num_query, bs, embed_dims].
for
lid
,
layer
in
enumerate
(
self
.
layers
):
"""
reference_points_input
=
reference_points
[...,
:
2
].
unsqueeze
(
2
)
# BS NUM_QUERY NUM_LEVEL 2
output
=
query
output
=
layer
(
intermediate
=
[]
output
,
intermediate_reference_points
=
[]
*
args
,
for
lid
,
layer
in
enumerate
(
self
.
layers
):
reference_points
=
reference_points_input
,
reference_points_input
=
reference_points
[...,
:
2
].
unsqueeze
(
key_padding_mask
=
key_padding_mask
,
2
)
# BS NUM_QUERY NUM_LEVEL 2
**
kwargs
)
output
=
layer
(
output
=
output
.
permute
(
1
,
0
,
2
)
output
,
*
args
,
if
reg_branches
is
not
None
:
reference_points
=
reference_points_input
,
tmp
=
reg_branches
[
lid
](
output
)
key_padding_mask
=
key_padding_mask
,
**
kwargs
)
assert
reference_points
.
shape
[
-
1
]
==
3
output
=
output
.
permute
(
1
,
0
,
2
)
new_reference_points
=
torch
.
zeros_like
(
reference_points
)
if
reg_branches
is
not
None
:
ref_center
=
(
tmp
[...,
:
3
]
+
tmp
[...,
-
3
:])
/
2
new_reference_points
=
ref_center
+
inverse_sigmoid
(
reference_points
)
tmp
=
reg_branches
[
lid
](
output
)
new_reference_points
=
new_reference_points
.
sigmoid
()
assert
reference_points
.
shape
[
-
1
]
==
3
reference_points
=
new_reference_points
.
detach
()
new_reference_points
=
torch
.
zeros_like
(
reference_points
)
output
=
output
.
permute
(
1
,
0
,
2
)
ref_center
=
(
tmp
[...,
:
3
]
+
tmp
[...,
-
3
:])
/
2
if
self
.
return_intermediate
:
new_reference_points
=
ref_center
+
inverse_sigmoid
(
reference_points
)
intermediate
.
append
(
output
)
new_reference_points
=
new_reference_points
.
sigmoid
()
intermediate_reference_points
.
append
(
reference_points
)
reference_points
=
new_reference_points
.
detach
()
if
self
.
return_intermediate
:
return
torch
.
stack
(
intermediate
),
torch
.
stack
(
output
=
output
.
permute
(
1
,
0
,
2
)
intermediate_reference_points
)
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
return
output
,
reference_points
intermediate_reference_points
.
append
(
reference_points
)
if
self
.
return_intermediate
:
@TRANSFORMER_LAYER.register_module()
class CustomDetrTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expand to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default:None
        act_cfg (dict): The activation config for FFNs. Default: `LN`
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default:2.
    """

    def __init__(self,
                 attn_cfgs,
                 ffn_cfgs,
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 **kwargs):
        super(CustomDetrTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            operation_order=operation_order,
            norm_cfg=norm_cfg,
            **kwargs)
        # This layer is restricted to the classic DETR decoder recipe:
        # exactly six ops covering self-attention, cross-attention,
        # normalization and the FFN.
        assert len(operation_order) == 6
        assert set(operation_order) == {'self_attn', 'norm', 'cross_attn',
                                        'ffn'}
@ATTENTION.register_module()
class CustomMSDeformableAttention(BaseModule):
    """An attention module used in Deformable-Detr.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(
            embed_dims, num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        # Initialize the sampling-offset bias so the per-head initial
        # sampling points are spread on rings around the reference point.
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements is range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area.
                or (N, Length_{query}, num_levels, 4), add
                additional two dimensions is (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """

        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        # Softmax over (levels * points) jointly, then reshape back.
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            # Offsets are expressed in feature-map pixels; normalize by (w, h)
            # of each level so sampling locations stay in [0, 1].
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Box-style references: scale offsets by half the box size (w, h).
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:

            # using fp16 deformable attention is unstable because it performs many sum operations
            # NOTE: both branches deliberately pick the fp32 kernel for that
            # reason; the branch is kept so an fp16 kernel can be re-enabled.
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/encoder.py
View file @
41b18fd8
# ---------------------------------------------
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# ---------------------------------------------
# Modified by Zhiqi Li
# Modified by Zhiqi Li
# ---------------------------------------------
# ---------------------------------------------
import
copy
from
.custom_base_transformer_layer
import
MyCustomBaseTransformerLayer
import
warnings
import
copy
import
warnings
import
numpy
as
np
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
import
torch
TRANSFORMER_LAYER
,
from
mmcv.cnn.bricks.registry
import
(
TRANSFORMER_LAYER
,
TRANSFORMER_LAYER_SEQUENCE
)
TRANSFORMER_LAYER_SEQUENCE
)
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
from
mmcv.cnn.bricks.transformer
import
TransformerLayerSequence
from
mmcv.runner
import
force_fp32
,
auto_fp16
from
mmcv.runner
import
auto_fp16
,
force_fp32
import
numpy
as
np
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
,
ext_loader
import
torch
import
cv2
as
cv
from
.custom_base_transformer_layer
import
MyCustomBaseTransformerLayer
import
mmcv
from
mmcv.utils
import
TORCH_VERSION
,
digit_version
ext_module
=
ext_loader
.
load_ext
(
from
mmcv.utils
import
ext_loader
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
class
BEVFormerEncoder
(
TransformerLayerSequence
):
@
TRANSFORMER_LAYER_SEQUENCE
.
register_module
()
"""
class
BEVFormerEncoder
(
TransformerLayerSequence
):
Attention with both self and cross
Implements the decoder in DETR transformer.
"""
Args:
Attention with both self and cross
return_intermediate (bool): Whether to return intermediate outputs.
Implements the decoder in DETR transformer.
coder_norm_cfg (dict): Config of last normalization layer. Default:
Args:
`LN`.
return_intermediate (bool): Whether to return intermediate outputs.
"""
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
def
__init__
(
self
,
*
args
,
pc_range
=
None
,
num_points_in_pillar
=
4
,
return_intermediate
=
False
,
dataset_type
=
'nuscenes'
,
"""
**
kwargs
):
def
__init__
(
self
,
*
args
,
pc_range
=
None
,
num_points_in_pillar
=
4
,
return_intermediate
=
False
,
dataset_type
=
'nuscenes'
,
super
(
BEVFormerEncoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
**
kwargs
):
self
.
return_intermediate
=
return_intermediate
super
(
BEVFormerEncoder
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
num_points_in_pillar
=
num_points_in_pillar
self
.
return_intermediate
=
return_intermediate
self
.
pc_range
=
pc_range
self
.
fp16_enabled
=
False
self
.
num_points_in_pillar
=
num_points_in_pillar
self
.
pc_range
=
pc_range
@
staticmethod
self
.
fp16_enabled
=
False
def
get_reference_points
(
H
,
W
,
Z
=
8
,
num_points_in_pillar
=
4
,
dim
=
'3d'
,
bs
=
1
,
device
=
'cuda'
,
dtype
=
torch
.
float
):
"""Get the reference points used in SCA and TSA.
@
staticmethod
Args:
def
get_reference_points
(
H
,
W
,
Z
=
8
,
num_points_in_pillar
=
4
,
dim
=
'3d'
,
bs
=
1
,
device
=
'cuda'
,
dtype
=
torch
.
float
):
H, W: spatial shape of bev.
"""Get the reference points used in SCA and TSA.
Z: hight of pillar.
Args:
D: sample D points uniformly from each pillar.
H, W: spatial shape of bev.
device (obj:`device`): The device where
Z: hight of pillar.
reference_points should be.
D: sample D points uniformly from each pillar.
Returns:
device (obj:`device`): The device where
Tensor: reference points used in decoder, has
\
reference_points should be.
shape (bs, num_keys, num_levels, 2).
Returns:
"""
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
# reference points in 3D space, used in spatial cross-attention (SCA)
"""
if
dim
==
'3d'
:
zs
=
torch
.
linspace
(
0.5
,
Z
-
0.5
,
num_points_in_pillar
,
dtype
=
dtype
,
# reference points in 3D space, used in spatial cross-attention (SCA)
device
=
device
).
view
(
-
1
,
1
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
Z
if
dim
==
'3d'
:
xs
=
torch
.
linspace
(
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
zs
=
torch
.
linspace
(
0.5
,
Z
-
0.5
,
num_points_in_pillar
,
dtype
=
dtype
,
device
=
device
).
view
(
1
,
1
,
W
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
W
device
=
device
).
view
(
-
1
,
1
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
Z
ys
=
torch
.
linspace
(
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
xs
=
torch
.
linspace
(
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
device
=
device
).
view
(
1
,
H
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
H
device
=
device
).
view
(
1
,
1
,
W
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
W
ref_3d
=
torch
.
stack
((
xs
,
ys
,
zs
),
-
1
)
ys
=
torch
.
linspace
(
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
ref_3d
=
ref_3d
.
permute
(
0
,
3
,
1
,
2
).
flatten
(
2
).
permute
(
0
,
2
,
1
)
device
=
device
).
view
(
1
,
H
,
1
).
expand
(
num_points_in_pillar
,
H
,
W
)
/
H
ref_3d
=
ref_3d
[
None
].
repeat
(
bs
,
1
,
1
,
1
)
ref_3d
=
torch
.
stack
((
xs
,
ys
,
zs
),
-
1
)
return
ref_3d
ref_3d
=
ref_3d
.
permute
(
0
,
3
,
1
,
2
).
flatten
(
2
).
permute
(
0
,
2
,
1
)
ref_3d
=
ref_3d
[
None
].
repeat
(
bs
,
1
,
1
,
1
)
# reference points on 2D bev plane, used in temporal self-attention (TSA).
return
ref_3d
elif
dim
==
'2d'
:
ref_y
,
ref_x
=
torch
.
meshgrid
(
# reference points on 2D bev plane, used in temporal self-attention (TSA).
torch
.
linspace
(
elif
dim
==
'2d'
:
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
device
=
device
),
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
torch
.
linspace
(
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
device
=
device
)
0.5
,
H
-
0.5
,
H
,
dtype
=
dtype
,
device
=
device
),
)
torch
.
linspace
(
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
H
0.5
,
W
-
0.5
,
W
,
dtype
=
dtype
,
device
=
device
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
W
)
ref_2d
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
H
ref_2d
=
ref_2d
.
repeat
(
bs
,
1
,
1
).
unsqueeze
(
2
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
W
return
ref_2d
ref_2d
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
ref_2d
=
ref_2d
.
repeat
(
bs
,
1
,
1
).
unsqueeze
(
2
)
# This function must use fp32!!!
return
ref_2d
@
force_fp32
(
apply_to
=
(
'reference_points'
,
'img_metas'
))
def
point_sampling
(
self
,
reference_points
,
pc_range
,
img_metas
):
# This function must use fp32!!!
@
force_fp32
(
apply_to
=
(
'reference_points'
,
'img_metas'
))
lidar2img
=
[]
def
point_sampling
(
self
,
reference_points
,
pc_range
,
img_metas
):
for
img_meta
in
img_metas
:
lidar2img
.
append
(
img_meta
[
'lidar2img'
])
lidar2img
=
[]
lidar2img
=
np
.
asarray
(
lidar2img
)
for
img_meta
in
img_metas
:
lidar2img
=
reference_points
.
new_tensor
(
lidar2img
)
# (B, N, 4, 4)
lidar2img
.
append
(
img_meta
[
'lidar2img'
])
reference_points
=
reference_points
.
clone
()
lidar2img
=
np
.
asarray
(
lidar2img
)
lidar2img
=
reference_points
.
new_tensor
(
lidar2img
)
# (B, N, 4, 4)
reference_points
[...,
0
:
1
]
=
reference_points
[...,
0
:
1
]
*
\
reference_points
=
reference_points
.
clone
()
(
pc_range
[
3
]
-
pc_range
[
0
])
+
pc_range
[
0
]
reference_points
[...,
1
:
2
]
=
reference_points
[...,
1
:
2
]
*
\
reference_points
[...,
0
:
1
]
=
reference_points
[...,
0
:
1
]
*
\
(
pc_range
[
4
]
-
pc_range
[
1
])
+
pc_range
[
1
]
(
pc_range
[
3
]
-
pc_range
[
0
])
+
pc_range
[
0
]
reference_points
[...,
2
:
3
]
=
reference_points
[...,
2
:
3
]
*
\
reference_points
[...,
1
:
2
]
=
reference_points
[...,
1
:
2
]
*
\
(
pc_range
[
5
]
-
pc_range
[
2
])
+
pc_range
[
2
]
(
pc_range
[
4
]
-
pc_range
[
1
])
+
pc_range
[
1
]
reference_points
[...,
2
:
3
]
=
reference_points
[...,
2
:
3
]
*
\
reference_points
=
torch
.
cat
(
(
pc_range
[
5
]
-
pc_range
[
2
])
+
pc_range
[
2
]
(
reference_points
,
torch
.
ones_like
(
reference_points
[...,
:
1
])),
-
1
)
reference_points
=
torch
.
cat
(
reference_points
=
reference_points
.
permute
(
1
,
0
,
2
,
3
)
(
reference_points
,
torch
.
ones_like
(
reference_points
[...,
:
1
])),
-
1
)
D
,
B
,
num_query
=
reference_points
.
size
()[:
3
]
num_cam
=
lidar2img
.
size
(
1
)
reference_points
=
reference_points
.
permute
(
1
,
0
,
2
,
3
)
D
,
B
,
num_query
=
reference_points
.
size
()[:
3
]
reference_points
=
reference_points
.
view
(
num_cam
=
lidar2img
.
size
(
1
)
D
,
B
,
1
,
num_query
,
4
).
repeat
(
1
,
1
,
num_cam
,
1
,
1
).
unsqueeze
(
-
1
)
reference_points
=
reference_points
.
view
(
lidar2img
=
lidar2img
.
view
(
D
,
B
,
1
,
num_query
,
4
).
repeat
(
1
,
1
,
num_cam
,
1
,
1
).
unsqueeze
(
-
1
)
1
,
B
,
num_cam
,
1
,
4
,
4
).
repeat
(
D
,
1
,
1
,
num_query
,
1
,
1
)
lidar2img
=
lidar2img
.
view
(
reference_points_cam
=
torch
.
matmul
(
lidar2img
.
to
(
torch
.
float32
),
1
,
B
,
num_cam
,
1
,
4
,
4
).
repeat
(
D
,
1
,
1
,
num_query
,
1
,
1
)
reference_points
.
to
(
torch
.
float32
)).
squeeze
(
-
1
)
eps
=
1e-5
reference_points_cam
=
torch
.
matmul
(
lidar2img
.
to
(
torch
.
float32
),
reference_points
.
to
(
torch
.
float32
)).
squeeze
(
-
1
)
bev_mask
=
(
reference_points_cam
[...,
2
:
3
]
>
eps
)
eps
=
1e-5
reference_points_cam
=
reference_points_cam
[...,
0
:
2
]
/
torch
.
maximum
(
reference_points_cam
[...,
2
:
3
],
torch
.
ones_like
(
reference_points_cam
[...,
2
:
3
])
*
eps
)
bev_mask
=
(
reference_points_cam
[...,
2
:
3
]
>
eps
)
reference_points_cam
=
reference_points_cam
[...,
0
:
2
]
/
torch
.
maximum
(
reference_points_cam
[...,
0
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
1
]
reference_points_cam
[...,
2
:
3
],
torch
.
ones_like
(
reference_points_cam
[...,
2
:
3
])
*
eps
)
reference_points_cam
[...,
1
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
0
]
reference_points_cam
[...,
0
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
1
]
bev_mask
=
(
bev_mask
&
(
reference_points_cam
[...,
1
:
2
]
>
0.0
)
reference_points_cam
[...,
1
]
/=
img_metas
[
0
][
'img_shape'
][
0
][
0
]
&
(
reference_points_cam
[...,
1
:
2
]
<
1.0
)
&
(
reference_points_cam
[...,
0
:
1
]
<
1.0
)
bev_mask
=
(
bev_mask
&
(
reference_points_cam
[...,
1
:
2
]
>
0.0
)
&
(
reference_points_cam
[...,
0
:
1
]
>
0.0
))
&
(
reference_points_cam
[...,
1
:
2
]
<
1.0
)
if
digit_version
(
TORCH_VERSION
)
>=
digit_version
(
'1.8'
):
&
(
reference_points_cam
[...,
0
:
1
]
<
1.0
)
bev_mask
=
torch
.
nan_to_num
(
bev_mask
)
&
(
reference_points_cam
[...,
0
:
1
]
>
0.0
))
else
:
if
digit_version
(
TORCH_VERSION
)
>=
digit_version
(
'1.8'
):
bev_mask
=
bev_mask
.
new_tensor
(
bev_mask
=
torch
.
nan_to_num
(
bev_mask
)
np
.
nan_to_num
(
bev_mask
.
cpu
().
numpy
()))
else
:
bev_mask
=
bev_mask
.
new_tensor
(
reference_points_cam
=
reference_points_cam
.
permute
(
2
,
1
,
3
,
0
,
4
)
np
.
nan_to_num
(
bev_mask
.
cpu
().
numpy
()))
bev_mask
=
bev_mask
.
permute
(
2
,
1
,
3
,
0
,
4
).
squeeze
(
-
1
)
reference_points_cam
=
reference_points_cam
.
permute
(
2
,
1
,
3
,
0
,
4
)
return
reference_points_cam
,
bev_mask
bev_mask
=
bev_mask
.
permute
(
2
,
1
,
3
,
0
,
4
).
squeeze
(
-
1
)
@
auto_fp16
()
return
reference_points_cam
,
bev_mask
def
forward
(
self
,
bev_query
,
@
auto_fp16
()
key
,
def
forward
(
self
,
value
,
bev_query
,
*
args
,
key
,
bev_h
=
None
,
value
,
bev_w
=
None
,
*
args
,
bev_pos
=
None
,
bev_h
=
None
,
spatial_shapes
=
None
,
bev_w
=
None
,
level_start_index
=
None
,
bev_pos
=
None
,
valid_ratios
=
None
,
spatial_shapes
=
None
,
prev_bev
=
None
,
level_start_index
=
None
,
shift
=
0.
,
valid_ratios
=
None
,
**
kwargs
):
prev_bev
=
None
,
"""Forward function for `TransformerDecoder`.
shift
=
0.
,
Args:
**
kwargs
):
bev_query (Tensor): Input BEV query with shape
"""Forward function for `TransformerDecoder`.
`(num_query, bs, embed_dims)`.
Args:
key & value (Tensor): Input multi-cameta features with shape
bev_query (Tensor): Input BEV query with shape
(num_cam, num_value, bs, embed_dims)
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference
key & value (Tensor): Input multi-cameta features with shape
points of offset. has shape
(num_cam, num_value, bs, embed_dims)
(bs, num_query, 4) when as_two_stage,
reference_points (Tensor): The reference
otherwise has shape ((bs, num_query, 2).
points of offset. has shape
valid_ratios (Tensor): The radios of valid
(bs, num_query, 4) when as_two_stage,
points on the feature map, has shape
otherwise has shape ((bs, num_query, 2).
(bs, num_levels, 2)
valid_ratios (Tensor): The radios of valid
Returns:
points on the feature map, has shape
Tensor: Results with shape [1, num_query, bs, embed_dims] when
(bs, num_levels, 2)
return_intermediate is `False`, otherwise it has shape
Returns:
[num_layers, num_query, bs, embed_dims].
Tensor: Results with shape [1, num_query, bs, embed_dims] when
"""
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
output
=
bev_query
"""
intermediate
=
[]
output
=
bev_query
ref_3d
=
self
.
get_reference_points
(
intermediate
=
[]
bev_h
,
bev_w
,
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
],
self
.
num_points_in_pillar
,
dim
=
'3d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
ref_3d
=
self
.
get_reference_points
(
ref_2d
=
self
.
get_reference_points
(
bev_h
,
bev_w
,
self
.
pc_range
[
5
]
-
self
.
pc_range
[
2
],
self
.
num_points_in_pillar
,
dim
=
'3d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
bev_h
,
bev_w
,
dim
=
'2d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
ref_2d
=
self
.
get_reference_points
(
bev_h
,
bev_w
,
dim
=
'2d'
,
bs
=
bev_query
.
size
(
1
),
device
=
bev_query
.
device
,
dtype
=
bev_query
.
dtype
)
reference_points_cam
,
bev_mask
=
self
.
point_sampling
(
ref_3d
,
self
.
pc_range
,
kwargs
[
'img_metas'
])
reference_points_cam
,
bev_mask
=
self
.
point_sampling
(
ref_3d
,
self
.
pc_range
,
kwargs
[
'img_metas'
])
# bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
shift_ref_2d
=
ref_2d
.
clone
()
# .clone()
# bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
shift_ref_2d
+=
shift
[:,
None
,
None
,
:]
shift_ref_2d
=
ref_2d
.
clone
()
# .clone()
shift_ref_2d
+=
shift
[:,
None
,
None
,
:]
# (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
bev_query
=
bev_query
.
permute
(
1
,
0
,
2
)
# (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
bev_pos
=
bev_pos
.
permute
(
1
,
0
,
2
)
bev_query
=
bev_query
.
permute
(
1
,
0
,
2
)
bs
,
len_bev
,
num_bev_level
,
_
=
ref_2d
.
shape
bev_pos
=
bev_pos
.
permute
(
1
,
0
,
2
)
if
prev_bev
is
not
None
:
bs
,
len_bev
,
num_bev_level
,
_
=
ref_2d
.
shape
prev_bev
=
prev_bev
.
permute
(
1
,
0
,
2
)
if
prev_bev
is
not
None
:
prev_bev
=
torch
.
stack
(
prev_bev
=
prev_bev
.
permute
(
1
,
0
,
2
)
[
prev_bev
,
bev_query
],
1
).
reshape
(
bs
*
2
,
len_bev
,
-
1
)
prev_bev
=
torch
.
stack
(
hybird_ref_2d
=
torch
.
stack
([
shift_ref_2d
,
ref_2d
],
1
).
reshape
(
[
prev_bev
,
bev_query
],
1
).
reshape
(
bs
*
2
,
len_bev
,
-
1
)
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
hybird_ref_2d
=
torch
.
stack
([
shift_ref_2d
,
ref_2d
],
1
).
reshape
(
else
:
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
hybird_ref_2d
=
torch
.
stack
([
ref_2d
,
ref_2d
],
1
).
reshape
(
else
:
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
hybird_ref_2d
=
torch
.
stack
([
ref_2d
,
ref_2d
],
1
).
reshape
(
bs
*
2
,
len_bev
,
num_bev_level
,
2
)
for
lid
,
layer
in
enumerate
(
self
.
layers
):
output
=
layer
(
for
lid
,
layer
in
enumerate
(
self
.
layers
):
bev_query
,
output
=
layer
(
key
,
bev_query
,
value
,
key
,
*
args
,
value
,
bev_pos
=
bev_pos
,
*
args
,
ref_2d
=
hybird_ref_2d
,
bev_pos
=
bev_pos
,
ref_3d
=
ref_3d
,
ref_2d
=
hybird_ref_2d
,
bev_h
=
bev_h
,
ref_3d
=
ref_3d
,
bev_w
=
bev_w
,
bev_h
=
bev_h
,
spatial_shapes
=
spatial_shapes
,
bev_w
=
bev_w
,
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
reference_points_cam
=
reference_points_cam
,
level_start_index
=
level_start_index
,
bev_mask
=
bev_mask
,
reference_points_cam
=
reference_points_cam
,
prev_bev
=
prev_bev
,
bev_mask
=
bev_mask
,
**
kwargs
)
prev_bev
=
prev_bev
,
**
kwargs
)
bev_query
=
output
if
self
.
return_intermediate
:
bev_query
=
output
intermediate
.
append
(
output
)
if
self
.
return_intermediate
:
intermediate
.
append
(
output
)
if
self
.
return_intermediate
:
return
torch
.
stack
(
intermediate
)
if
self
.
return_intermediate
:
return
torch
.
stack
(
intermediate
)
return
output
return
output
@
TRANSFORMER_LAYER
.
register_module
()
class
BEVFormerLayer
(
MyCustomBaseTransformerLayer
):
@
TRANSFORMER_LAYER
.
register_module
()
"""Implements decoder layer in DETR transformer.
class
BEVFormerLayer
(
MyCustomBaseTransformerLayer
):
Args:
"""Implements decoder layer in DETR transformer.
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
Args:
Configs for self_attention or cross_attention, the order
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
should be consistent with it in `operation_order`. If it is
Configs for self_attention or cross_attention, the order
a dict, it would be expand to the number of attention in
should be consistent with it in `operation_order`. If it is
`operation_order`.
a dict, it would be expand to the number of attention in
feedforward_channels (int): The hidden dimension for FFNs.
`operation_order`.
ffn_dropout (float): Probability of an element to be zeroed
feedforward_channels (int): The hidden dimension for FFNs.
in ffn. Default 0.0.
ffn_dropout (float): Probability of an element to be zeroed
operation_order (tuple[str]): The execution order of operation
in ffn. Default 0.0.
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
operation_order (tuple[str]): The execution order of operation
Default:None
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
act_cfg (dict): The activation config for FFNs. Default: `LN`
Default:None
norm_cfg (dict): Config dict for normalization layer.
act_cfg (dict): The activation config for FFNs. Default: `LN`
Default: `LN`.
norm_cfg (dict): Config dict for normalization layer.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default: `LN`.
Default:2.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
"""
Default:2.
"""
def
__init__
(
self
,
attn_cfgs
,
def
__init__
(
self
,
ffn_cfgs
,
attn_cfgs
,
operation_order
=
None
,
ffn_cfgs
,
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
operation_order
=
None
,
norm_cfg
=
dict
(
type
=
'LN'
),
act_cfg
=
dict
(
type
=
'ReLU'
,
inplace
=
True
),
**
kwargs
):
norm_cfg
=
dict
(
type
=
'LN'
),
super
(
BEVFormerLayer
,
self
).
__init__
(
**
kwargs
):
attn_cfgs
=
attn_cfgs
,
super
(
BEVFormerLayer
,
self
).
__init__
(
ffn_cfgs
=
ffn_cfgs
,
attn_cfgs
=
attn_cfgs
,
operation_order
=
operation_order
,
ffn_cfgs
=
ffn_cfgs
,
act_cfg
=
act_cfg
,
operation_order
=
operation_order
,
norm_cfg
=
norm_cfg
,
act_cfg
=
act_cfg
,
**
kwargs
)
norm_cfg
=
norm_cfg
,
self
.
fp16_enabled
=
False
**
kwargs
)
assert
len
(
operation_order
)
==
6
self
.
fp16_enabled
=
False
assert
set
(
operation_order
)
==
set
(
assert
len
(
operation_order
)
==
6
[
'self_attn'
,
'norm'
,
'cross_attn'
,
'ffn'
])
assert
set
(
operation_order
)
==
set
(
[
'self_attn'
,
'norm'
,
'cross_attn'
,
'ffn'
])
def
forward
(
self
,
query
,
def
forward
(
self
,
key
=
None
,
query
,
value
=
None
,
key
=
None
,
bev_pos
=
None
,
value
=
None
,
query_pos
=
None
,
bev_pos
=
None
,
key_pos
=
None
,
query_pos
=
None
,
attn_masks
=
None
,
key_pos
=
None
,
query_key_padding_mask
=
None
,
attn_masks
=
None
,
key_padding_mask
=
None
,
query_key_padding_mask
=
None
,
ref_2d
=
None
,
key_padding_mask
=
None
,
ref_3d
=
None
,
ref_2d
=
None
,
bev_h
=
None
,
ref_3d
=
None
,
bev_w
=
None
,
bev_h
=
None
,
reference_points_cam
=
None
,
bev_w
=
None
,
mask
=
None
,
reference_points_cam
=
None
,
spatial_shapes
=
None
,
mask
=
None
,
level_start_index
=
None
,
spatial_shapes
=
None
,
prev_bev
=
None
,
level_start_index
=
None
,
**
kwargs
):
prev_bev
=
None
,
"""Forward function for `TransformerDecoderLayer`.
**
kwargs
):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
Args:
[num_queries, bs, embed_dims] if
query (Tensor): The input query with shape
self.batch_first is False, else
[num_queries, bs, embed_dims] if
[bs, num_queries embed_dims].
self.batch_first is False, else
key (Tensor): The key tensor with shape [num_keys, bs,
[bs, num_queries embed_dims].
embed_dims] if self.batch_first is False, else
key (Tensor): The key tensor with shape [num_keys, bs,
[bs, num_keys, embed_dims] .
embed_dims] if self.batch_first is False, else
value (Tensor): The value tensor with same shape as `key`.
[bs, num_keys, embed_dims] .
query_pos (Tensor): The positional encoding for `query`.
value (Tensor): The value tensor with same shape as `key`.
Default: None.
query_pos (Tensor): The positional encoding for `query`.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
attn_masks (List[Tensor] | None): 2D Tensor used in
Default: None.
calculation of corresponding attention. The length of
attn_masks (List[Tensor] | None): 2D Tensor used in
it should equal to the number of `attention` in
calculation of corresponding attention. The length of
`operation_order`. Default: None.
it should equal to the number of `attention` in
query_key_padding_mask (Tensor): ByteTensor for `query`, with
`operation_order`. Default: None.
shape [bs, num_queries]. Only used in `self_attn` layer.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
Defaults to None.
shape [bs, num_queries]. Only used in `self_attn` layer.
key_padding_mask (Tensor): ByteTensor for `query`, with
Defaults to None.
shape [bs, num_keys]. Default: None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
Returns:
"""
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index
=
0
attn_index
=
0
norm_index
=
0
ffn_index
=
0
attn_index
=
0
identity
=
query
ffn_index
=
0
if
attn_masks
is
None
:
identity
=
query
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
if
attn_masks
is
None
:
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
attn_masks
=
[
None
for
_
in
range
(
self
.
num_attn
)]
attn_masks
=
[
elif
isinstance
(
attn_masks
,
torch
.
Tensor
):
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
attn_masks
=
[
]
copy
.
deepcopy
(
attn_masks
)
for
_
in
range
(
self
.
num_attn
)
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
]
f
'
{
self
.
__class__
.
__name__
}
'
)
warnings
.
warn
(
f
'Use same attn_mask in all attentions in '
else
:
f
'
{
self
.
__class__
.
__name__
}
'
)
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
else
:
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
assert
len
(
attn_masks
)
==
self
.
num_attn
,
f
'The length of '
\
f
'to the number of attention in '
\
f
'attn_masks
{
len
(
attn_masks
)
}
must be equal '
\
f
'operation_order
{
self
.
num_attn
}
'
f
'to the number of attention in '
\
f
'operation_order
{
self
.
num_attn
}
'
for
layer
in
self
.
operation_order
:
# temporal self attention
for
layer
in
self
.
operation_order
:
if
layer
==
'self_attn'
:
# temporal self attention
if
layer
==
'self_attn'
:
query
=
self
.
attentions
[
attn_index
](
query
,
query
=
self
.
attentions
[
attn_index
](
prev_bev
,
query
,
prev_bev
,
prev_bev
,
identity
if
self
.
pre_norm
else
None
,
prev_bev
,
query_pos
=
bev_pos
,
identity
if
self
.
pre_norm
else
None
,
key_pos
=
bev_pos
,
query_pos
=
bev_pos
,
attn_mask
=
attn_masks
[
attn_index
],
key_pos
=
bev_pos
,
key_padding_mask
=
query_key_padding_mask
,
attn_mask
=
attn_masks
[
attn_index
],
reference_points
=
ref_2d
,
key_padding_mask
=
query_key_padding_mask
,
spatial_shapes
=
torch
.
tensor
(
reference_points
=
ref_2d
,
[[
bev_h
,
bev_w
]],
device
=
query
.
device
),
spatial_shapes
=
torch
.
tensor
(
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
[[
bev_h
,
bev_w
]],
device
=
query
.
device
),
**
kwargs
)
level_start_index
=
torch
.
tensor
([
0
],
device
=
query
.
device
),
attn_index
+=
1
**
kwargs
)
identity
=
query
attn_index
+=
1
identity
=
query
elif
layer
==
'norm'
:
query
=
self
.
norms
[
norm_index
](
query
)
elif
layer
==
'norm'
:
norm_index
+=
1
query
=
self
.
norms
[
norm_index
](
query
)
norm_index
+=
1
# spaital cross attention
elif
layer
==
'cross_attn'
:
# spaital cross attention
query
=
self
.
attentions
[
attn_index
](
elif
layer
==
'cross_attn'
:
query
,
query
=
self
.
attentions
[
attn_index
](
key
,
query
,
value
,
key
,
identity
if
self
.
pre_norm
else
None
,
value
,
query_pos
=
query_pos
,
identity
if
self
.
pre_norm
else
None
,
key_pos
=
key_pos
,
query_pos
=
query_pos
,
reference_points
=
ref_3d
,
key_pos
=
key_pos
,
reference_points_cam
=
reference_points_cam
,
reference_points
=
ref_3d
,
mask
=
mask
,
reference_points_cam
=
reference_points_cam
,
attn_mask
=
attn_masks
[
attn_index
],
mask
=
mask
,
key_padding_mask
=
key_padding_mask
,
attn_mask
=
attn_masks
[
attn_index
],
spatial_shapes
=
spatial_shapes
,
key_padding_mask
=
key_padding_mask
,
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
**
kwargs
)
level_start_index
=
level_start_index
,
attn_index
+=
1
**
kwargs
)
identity
=
query
attn_index
+=
1
identity
=
query
elif
layer
==
'ffn'
:
query
=
self
.
ffns
[
ffn_index
](
elif
layer
==
'ffn'
:
query
,
identity
if
self
.
pre_norm
else
None
)
query
=
self
.
ffns
[
ffn_index
](
ffn_index
+=
1
query
,
identity
if
self
.
pre_norm
else
None
)
ffn_index
+=
1
return
query
return
query
autonomous_driving/openlane-v2/plugin/mmdet3d/baseline/models/modules/multi_scale_deformable_attn_function.py
View file @
41b18fd8
# ---------------------------------------------
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# ---------------------------------------------
# Modified by Zhiqi Li
# Modified by Zhiqi Li
# ---------------------------------------------
# ---------------------------------------------
import
torch
import
torch
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
mmcv.utils
import
ext_loader
from
torch.autograd.function
import
Function
,
once_differentiable
from
torch.autograd.function
import
Function
,
once_differentiable
from
mmcv.utils
import
ext_loader
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ms_deform_attn_backward'
,
'ms_deform_attn_forward'
])
class
MultiScaleDeformableAttnFunction_fp16
(
Function
):
class
MultiScaleDeformableAttnFunction_fp16
(
Function
):
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
@
staticmethod
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
sampling_locations
,
attention_weights
,
im2col_step
):
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
"""GPU version of multi-scale deformable attention.
sampling_locations
,
attention_weights
,
im2col_step
):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
Args:
(bs, num_keys, mum_heads, embed_dims//num_heads)
value (Tensor): The value has shape
value_spatial_shapes (Tensor): Spatial shape of
(bs, num_keys, mum_heads, embed_dims//num_heads)
each feature map, has shape (num_levels, 2),
value_spatial_shapes (Tensor): Spatial shape of
last dimension 2 represent (h, w)
each feature map, has shape (num_levels, 2),
sampling_locations (Tensor): The location of sampling points,
last dimension 2 represent (h, w)
has shape
sampling_locations (Tensor): The location of sampling points,
(bs ,num_queries, num_heads, num_levels, num_points, 2),
has shape
the last dimension 2 represent (x, y).
(bs ,num_queries, num_heads, num_levels, num_points, 2),
attention_weights (Tensor): The weight of sampling points used
the last dimension 2 represent (x, y).
when calculate the attention, has shape
attention_weights (Tensor): The weight of sampling points used
(bs ,num_queries, num_heads, num_levels, num_points),
when calculate the attention, has shape
im2col_step (Tensor): The step used in image to column.
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
Returns:
"""
Tensor: has shape (bs, num_queries, embed_dims)
ctx
.
im2col_step
=
im2col_step
"""
output
=
ext_module
.
ms_deform_attn_forward
(
ctx
.
im2col_step
=
im2col_step
value
,
output
=
ext_module
.
ms_deform_attn_forward
(
value_spatial_shapes
,
value
,
value_level_start_index
,
value_spatial_shapes
,
sampling_locations
,
value_level_start_index
,
attention_weights
,
sampling_locations
,
im2col_step
=
ctx
.
im2col_step
)
attention_weights
,
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
im2col_step
=
ctx
.
im2col_step
)
value_level_start_index
,
sampling_locations
,
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
attention_weights
)
value_level_start_index
,
sampling_locations
,
return
output
attention_weights
)
return
output
@
staticmethod
@
once_differentiable
@
staticmethod
@
custom_bwd
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
@
custom_bwd
"""GPU version of backward function.
def
backward
(
ctx
,
grad_output
):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
Args:
of output tensor of forward.
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
Returns:
of input tensors in forward.
Tuple[Tensor]: Gradient
"""
of input tensors in forward.
value
,
value_spatial_shapes
,
value_level_start_index
,
\
"""
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
value
,
value_spatial_shapes
,
value_level_start_index
,
\
grad_value
=
torch
.
zeros_like
(
value
)
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_value
=
torch
.
zeros_like
(
value
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
ext_module
.
ms_deform_attn_backward
(
value
,
ext_module
.
ms_deform_attn_backward
(
value_spatial_shapes
,
value
,
value_level_start_index
,
value_spatial_shapes
,
sampling_locations
,
value_level_start_index
,
attention_weights
,
sampling_locations
,
grad_output
.
contiguous
(),
attention_weights
,
grad_value
,
grad_output
.
contiguous
(),
grad_sampling_loc
,
grad_value
,
grad_attn_weight
,
grad_sampling_loc
,
im2col_step
=
ctx
.
im2col_step
)
grad_attn_weight
,
im2col_step
=
ctx
.
im2col_step
)
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
class
MultiScaleDeformableAttnFunction_fp32
(
Function
):
class
MultiScaleDeformableAttnFunction_fp32
(
Function
):
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
@
staticmethod
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
sampling_locations
,
attention_weights
,
im2col_step
):
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
"""GPU version of multi-scale deformable attention.
sampling_locations
,
attention_weights
,
im2col_step
):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
Args:
(bs, num_keys, mum_heads, embed_dims//num_heads)
value (Tensor): The value has shape
value_spatial_shapes (Tensor): Spatial shape of
(bs, num_keys, mum_heads, embed_dims//num_heads)
each feature map, has shape (num_levels, 2),
value_spatial_shapes (Tensor): Spatial shape of
last dimension 2 represent (h, w)
each feature map, has shape (num_levels, 2),
sampling_locations (Tensor): The location of sampling points,
last dimension 2 represent (h, w)
has shape
sampling_locations (Tensor): The location of sampling points,
(bs ,num_queries, num_heads, num_levels, num_points, 2),
has shape
the last dimension 2 represent (x, y).
(bs ,num_queries, num_heads, num_levels, num_points, 2),
attention_weights (Tensor): The weight of sampling points used
the last dimension 2 represent (x, y).
when calculate the attention, has shape
attention_weights (Tensor): The weight of sampling points used
(bs ,num_queries, num_heads, num_levels, num_points),
when calculate the attention, has shape
im2col_step (Tensor): The step used in image to column.
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
Returns:
"""
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx
.
im2col_step
=
im2col_step
output
=
ext_module
.
ms_deform_attn_forward
(
ctx
.
im2col_step
=
im2col_step
value
,
output
=
ext_module
.
ms_deform_attn_forward
(
value_spatial_shapes
,
value
,
value_level_start_index
,
value_spatial_shapes
,
sampling_locations
,
value_level_start_index
,
attention_weights
,
sampling_locations
,
im2col_step
=
ctx
.
im2col_step
)
attention_weights
,
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
im2col_step
=
ctx
.
im2col_step
)
value_level_start_index
,
sampling_locations
,
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
attention_weights
)
value_level_start_index
,
sampling_locations
,
return
output
attention_weights
)
return
output
@
staticmethod
@
once_differentiable
@
staticmethod
@
custom_bwd
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
@
custom_bwd
"""GPU version of backward function.
def
backward
(
ctx
,
grad_output
):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
Args:
of output tensor of forward.
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
Returns:
of input tensors in forward.
Tuple[Tensor]: Gradient
"""
of input tensors in forward.
value
,
value_spatial_shapes
,
value_level_start_index
,
\
"""
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
value
,
value_spatial_shapes
,
value_level_start_index
,
\
grad_value
=
torch
.
zeros_like
(
value
)
sampling_locations
,
attention_weights
=
ctx
.
saved_tensors
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_value
=
torch
.
zeros_like
(
value
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
grad_sampling_loc
=
torch
.
zeros_like
(
sampling_locations
)
grad_attn_weight
=
torch
.
zeros_like
(
attention_weights
)
ext_module
.
ms_deform_attn_backward
(
value
,
ext_module
.
ms_deform_attn_backward
(
value_spatial_shapes
,
value
,
value_level_start_index
,
value_spatial_shapes
,
sampling_locations
,
value_level_start_index
,
attention_weights
,
sampling_locations
,
grad_output
.
contiguous
(),
attention_weights
,
grad_value
,
grad_output
.
contiguous
(),
grad_sampling_loc
,
grad_value
,
grad_attn_weight
,
grad_sampling_loc
,
im2col_step
=
ctx
.
im2col_step
)
grad_attn_weight
,
im2col_step
=
ctx
.
im2col_step
)
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
return
grad_value
,
None
,
None
,
\
grad_sampling_loc
,
grad_attn_weight
,
None
Prev
1
…
7
8
9
10
11
12
13
14
15
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment