lishj6 / Flashocc · Commit 3b8d508a
authored Sep 05, 2025 by lishj6
init_0905
parent e968ab0f
Pipeline #2906 canceled with stages · Changes 156 · Pipelines 1
Showing 20 changed files with 7643 additions and 0 deletions (+7643, -0)
projects/mmdet3d_plugin/models/backbones/resnet.py                  +190   -0
projects/mmdet3d_plugin/models/backbones/swin.py                    +977   -0
projects/mmdet3d_plugin/models/dense_heads/__init__.py              +4     -0
projects/mmdet3d_plugin/models/dense_heads/bev_centerpoint_head.py  +1764  -0
projects/mmdet3d_plugin/models/dense_heads/bev_occ_head.py          +405   -0
projects/mmdet3d_plugin/models/detectors/__init__.py                +11    -0
projects/mmdet3d_plugin/models/detectors/bevdepth.py                +258   -0
projects/mmdet3d_plugin/models/detectors/bevdepth4d.py              +58    -0
projects/mmdet3d_plugin/models/detectors/bevdet.py                  +252   -0
projects/mmdet3d_plugin/models/detectors/bevdet4d.py                +387   -0
projects/mmdet3d_plugin/models/detectors/bevdet_occ.py              +1481  -0
projects/mmdet3d_plugin/models/detectors/bevstereo4d.py             +287   -0
projects/mmdet3d_plugin/models/losses/__init__.py                   +4     -0
projects/mmdet3d_plugin/models/losses/cross_entropy_loss.py         +302   -0
projects/mmdet3d_plugin/models/losses/focal_loss.py                 +265   -0
projects/mmdet3d_plugin/models/losses/lovasz_softmax.py             +329   -0
projects/mmdet3d_plugin/models/losses/semkitti_loss.py              +185   -0
projects/mmdet3d_plugin/models/model_utils/__init__.py              +3     -0
projects/mmdet3d_plugin/models/model_utils/depthnet.py              +476   -0
projects/mmdet3d_plugin/models/necks/__init__.py                    +5     -0
projects/mmdet3d_plugin/models/backbones/resnet.py (new file, mode 100644)
# Copyright (c) Phigent Robotics. All rights reserved.
import torch.utils.checkpoint as checkpoint
from torch import nn
import torch
from mmcv.cnn.bricks.conv_module import ConvModule
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck

from mmdet3d.models import BACKBONES


@BACKBONES.register_module()
class CustomResNet(nn.Module):

    def __init__(
            self,
            numC_input,
            num_layer=[2, 2, 2],
            num_channels=None,
            stride=[2, 2, 2],
            backbone_output_ids=None,
            norm_cfg=dict(type='BN'),
            with_cp=False,
            block_type='Basic',
    ):
        super(CustomResNet, self).__init__()
        # build backbone
        assert len(num_layer) == len(stride)
        num_channels = [numC_input * 2 ** (i + 1) for i in range(len(num_layer))] \
            if num_channels is None else num_channels
        self.backbone_output_ids = range(len(num_layer)) \
            if backbone_output_ids is None else backbone_output_ids
        layers = []
        if block_type == 'BottleNeck':
            curr_numC = numC_input
            for i in range(len(num_layer)):
                # Downsample the input in the first block of each stage.
                layer = [
                    Bottleneck(
                        inplanes=curr_numC,
                        planes=num_channels[i] // 4,
                        stride=stride[i],
                        downsample=nn.Conv2d(curr_numC, num_channels[i], 3,
                                             stride[i], 1),
                        norm_cfg=norm_cfg)
                ]
                curr_numC = num_channels[i]
                layer.extend([
                    Bottleneck(
                        inplanes=curr_numC,
                        planes=num_channels[i] // 4,
                        stride=1,
                        downsample=None,
                        norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)
                ])
                layers.append(nn.Sequential(*layer))
        elif block_type == 'Basic':
            curr_numC = numC_input
            for i in range(len(num_layer)):
                # Downsample the input in the first block of each stage.
                layer = [
                    BasicBlock(
                        inplanes=curr_numC,
                        planes=num_channels[i],
                        stride=stride[i],
                        downsample=nn.Conv2d(curr_numC, num_channels[i], 3,
                                             stride[i], 1),
                        norm_cfg=norm_cfg)
                ]
                curr_numC = num_channels[i]
                layer.extend([
                    BasicBlock(
                        inplanes=curr_numC,
                        planes=num_channels[i],
                        stride=1,
                        downsample=None,
                        norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)
                ])
                layers.append(nn.Sequential(*layer))
        else:
            assert False
        self.layers = nn.Sequential(*layers)
        self.with_cp = with_cp

    # @torch.compile
    def forward(self, x):
        """
        Args:
            x: (B, C=64, Dy, Dx)
        Returns:
            feats: List[
                (B, 2*C, Dy/2, Dx/2),
                (B, 4*C, Dy/4, Dx/4),
                (B, 8*C, Dy/8, Dx/8),
            ]
        """
        feats = []
        x_tmp = x
        for lid, layer in enumerate(self.layers):
            if self.with_cp:
                x_tmp = checkpoint.checkpoint(layer, x_tmp)
            else:
                x_tmp = layer(x_tmp)
            if lid in self.backbone_output_ids:
                feats.append(x_tmp)
        return feats
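

def _demo_custom_resnet():
    # Sketch (not part of the original commit): a shape check for CustomResNet
    # above with its default num_layer=[2, 2, 2] and stride=[2, 2, 2];
    # it assumes the mmdet/mmcv imports at the top of this file resolve.
    model = CustomResNet(numC_input=64)
    x = torch.randn(1, 64, 128, 128)    # (B, C=64, Dy, Dx)
    for f in model(x):
        print(f.shape)
    # Expected: (1, 128, 64, 64), (1, 256, 32, 32), (1, 512, 16, 16)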
class BasicBlock3D(nn.Module):
    def __init__(self, channels_in, channels_out, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = ConvModule(
            channels_in,
            channels_out,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', ),
            act_cfg=dict(type='ReLU', inplace=True))
        self.conv2 = ConvModule(
            channels_out,
            channels_out,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', ),
            act_cfg=None)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if self.downsample is not None:
            identity = self.downsample(x)
        else:
            identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = x + identity
        return self.relu(x)
@BACKBONES.register_module()
class CustomResNet3D(nn.Module):
    def __init__(
            self,
            numC_input,
            num_layer=[2, 2, 2],
            num_channels=None,
            stride=[2, 2, 2],
            backbone_output_ids=None,
            with_cp=False,
    ):
        super(CustomResNet3D, self).__init__()
        # build backbone
        assert len(num_layer) == len(stride)
        num_channels = [numC_input * 2 ** (i + 1) for i in range(len(num_layer))] \
            if num_channels is None else num_channels
        self.backbone_output_ids = range(len(num_layer)) \
            if backbone_output_ids is None else backbone_output_ids
        layers = []
        curr_numC = numC_input
        for i in range(len(num_layer)):
            layer = [
                BasicBlock3D(
                    curr_numC,
                    num_channels[i],
                    stride=stride[i],
                    downsample=ConvModule(
                        curr_numC,
                        num_channels[i],
                        kernel_size=3,
                        stride=stride[i],
                        padding=1,
                        bias=False,
                        conv_cfg=dict(type='Conv3d'),
                        norm_cfg=dict(type='BN3d', ),
                        act_cfg=None))
            ]
            curr_numC = num_channels[i]
            layer.extend([
                BasicBlock3D(curr_numC, curr_numC)
                for _ in range(num_layer[i] - 1)
            ])
            layers.append(nn.Sequential(*layer))
        self.layers = nn.Sequential(*layers)
        self.with_cp = with_cp

    def forward(self, x):
        """
        Args:
            x: (B, C, Dz, Dy, Dx)
        Returns:
            feats: List[
                (B, C, Dz, Dy, Dx),
                (B, 2C, Dz/2, Dy/2, Dx/2),
                (B, 4C, Dz/4, Dy/4, Dx/4),
            ]
        """
        feats = []
        x_tmp = x
        for lid, layer in enumerate(self.layers):
            if self.with_cp:
                x_tmp = checkpoint.checkpoint(layer, x_tmp)
            else:
                x_tmp = layer(x_tmp)
            if lid in self.backbone_output_ids:
                feats.append(x_tmp)
        return feats
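
A quick usage sketch (not part of the commit) for the 3D backbone above. It assumes mmcv/mmdet3d are installed and that the plugin is importable under the path shown in this diff; shapes follow the forward docstring:

import torch
from projects.mmdet3d_plugin.models.backbones.resnet import CustomResNet3D

model = CustomResNet3D(numC_input=32, num_layer=[2, 2, 2], stride=[1, 2, 2])
x = torch.randn(1, 32, 16, 100, 100)    # (B, C, Dz, Dy, Dx)
for f in model(x):
    print(f.shape)
# (1, 64, 16, 100, 100), (1, 128, 8, 50, 50), (1, 256, 4, 25, 25)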
projects/mmdet3d_plugin/models/backbones/swin.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init, build_conv_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
import torch.utils.checkpoint as checkpoint

from mmseg.ops import resize
from mmdet3d.utils import get_root_logger
from mmdet3d.models.builder import BACKBONES
from mmcv.cnn.bricks.registry import ATTENTION
from torch.nn.modules.utils import _pair as to_2tuple
from collections import OrderedDict


def swin_convert(ckpt):
    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v

    return new_ckpt
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
    """Image to Patch Embedding V2.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (dict, optional): The config dict for conv layers type
            selection. Default: None.
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The slide stride of embedding conv.
            Default: None (defaults to be equal with kernel_size).
        padding (int): The padding length of embedding conv. Default: 0.
        dilation (int): The dilation rate of embedding conv. Default: 1.
        pad_to_patch_size (bool, optional): Whether to pad the feature map
            shape to a multiple of the patch size. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type=None,
                 kernel_size=16,
                 stride=16,
                 padding=0,
                 dilation=1,
                 pad_to_patch_size=True,
                 norm_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__()

        self.embed_dims = embed_dims
        self.init_cfg = init_cfg

        if stride is None:
            stride = kernel_size

        self.pad_to_patch_size = pad_to_patch_size

        # The default setting of patch size is equal to kernel size.
        patch_size = kernel_size
        if isinstance(patch_size, int):
            patch_size = to_2tuple(patch_size)
        elif isinstance(patch_size, tuple):
            if len(patch_size) == 1:
                patch_size = to_2tuple(patch_size[0])
            assert len(patch_size) == 2, \
                f'The size of patch should have length 1 or 2, ' \
                f'but got {len(patch_size)}'

        self.patch_size = patch_size

        # Use conv layer to embed
        conv_type = conv_type or 'Conv2d'
        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        H, W = x.shape[2], x.shape[3]

        # TODO: Process overlapping op
        if self.pad_to_patch_size:
            # Modify H, W to multiple of patch size.
            if H % self.patch_size[0] != 0:
                x = F.pad(x,
                          (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
            if W % self.patch_size[1] != 0:
                x = F.pad(x,
                          (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))

        x = self.projection(x)
        self.DH, self.DW = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        return x
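

def _demo_patch_embed():
    # Sketch (not part of the original commit): with kernel_size=stride=4 and
    # embed_dims=96, a (1, 3, 224, 224) image becomes a (1, 56*56, 96) token
    # sequence, and (DH, DW) = (56, 56) records the 2D layout for later stages.
    patch_embed = PatchEmbed(in_channels=3, embed_dims=96,
                             kernel_size=4, stride=4)
    x = torch.randn(1, 3, 224, 224)
    tokens = patch_embed(x)
    print(tokens.shape, patch_embed.DH, patch_embed.DW)
    # torch.Size([1, 3136, 96]) 56 56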
class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer uses nn.Unfold to group the feature map by kernel_size, and
    uses a norm and linear layer to embed the grouped feature map.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        stride (int | tuple): the stride of the sliding length in the
            unfold layer. Defaults: 2 (default to be equal with kernel_size).
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Defaults: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=2,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

        self.sampler = nn.Unfold(
            kernel_size=stride, dilation=1, padding=0, stride=stride)

        sample_dim = stride**2 * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, hw_shape):
        """
        x: x.shape -> [B, H*W, C]
        hw_shape: (H, W)
        """
        B, L, C = x.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])    # B, C, H, W

        # stride is fixed to be equal to kernel_size.
        if (H % self.stride != 0) or (W % self.stride != 0):
            x = F.pad(x, (0, W % self.stride, 0, H % self.stride))

        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility
        x = self.sampler(x)         # B, 4*C, H/2*W/2
        x = x.transpose(1, 2)       # B, H/2*W/2, 4*C

        x = self.norm(x) if self.norm else x
        x = self.reduction(x)

        down_hw_shape = (H + 1) // 2, (W + 1) // 2
        return x, down_hw_shape
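

def _demo_patch_merging():
    # Sketch (not part of the original commit): merging halves the spatial
    # resolution and doubles the channels, e.g. 56x56x96 tokens -> 28x28x192.
    merge = PatchMerging(in_channels=96, out_channels=192)
    x = torch.randn(1, 56 * 56, 96)
    out, down_hw = merge(x, (56, 56))
    print(out.shape, down_hw)    # torch.Size([1, 784, 192]) (28, 28)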
@ATTENTION.register_module()
class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):
        super().__init__()
        self.embed_dims = embed_dims
        self.window_size = window_size    # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5
        self.init_cfg = init_cfg

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))    # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        trunc_normal_init(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)    # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()    # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
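

def _demo_relative_position_index():
    # Sketch (not part of the original commit): for a 2x2 window the
    # double_step_seq trick yields a (Wh*Ww, Wh*Ww) index whose entries all
    # fall in [0, (2*Wh-1)*(2*Ww-1)) and select rows of the learned bias table.
    Wh = Ww = 2
    rel_index_coords = WindowMSA.double_step_seq(2 * Ww - 1, Wh, 1, Ww)  # (1, 4)
    rel_position_index = (rel_index_coords + rel_index_coords.T).flip(1)
    print(rel_position_index)
    # tensor([[4, 3, 1, 0],
    #         [5, 4, 2, 1],
    #         [7, 6, 4, 3],
    #         [8, 7, 5, 4]])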
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
    """Shift Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults: 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults: 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults: dict(type='DropPath', drop_prob=0.).
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0,
                 proj_drop_rate=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 init_cfg=None):
        super().__init__(init_cfg)

        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(window_size),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            init_cfg=None)

        self.drop = build_dropout(dropout_layer)

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(
                query,
                shifts=(-self.shift_size, -self.shift_size),
                dims=(1, 2))

            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H_pad, W_pad, 1),
                                   device=query.device)    # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            shifted_query = query
            attn_mask = None

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x

    def window_reverse(self, windows, H, W):
        """
        Args:
            windows: (num_windows*B, window_size, window_size, C)
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        window_size = self.window_size
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    def window_partition(self, x):
        """
        Args:
            x: (B, H, W, C)
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
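

def _demo_window_round_trip():
    # Sketch (not part of the original commit): window_partition followed by
    # window_reverse is lossless on a feature map padded to the window size.
    msa = ShiftWindowMSA(embed_dims=96, num_heads=3, window_size=7)
    x = torch.randn(2, 14, 14, 96)
    wins = msa.window_partition(x)         # (2*2*2, 7, 7, 96)
    y = msa.window_reverse(wins, 14, 14)   # (2, 14, 14, 96)
    assert torch.equal(x, y)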
class SwinBlock(BaseModule):
    """
    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        window_size (int, optional): The local window scale. Default: 7.
        shift (bool): whether to shift window or not. Default: False.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(SwinBlock, self).__init__()

        self.init_cfg = init_cfg

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, x, hw_shape):
        identity = x
        x = self.norm1(x)
        x = self.attn(x, hw_shape)
        x = x + identity

        identity = x
        x = self.norm2(x)
        x = self.ffn(x, identity=identity)
        return x
class SwinBlockSequence(BaseModule):
    """Implements one stage in Swin Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int): The local window scale. Default: 7.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 with_cp=True):
        super().__init__()

        self.init_cfg = init_cfg

        drop_path_rate = drop_path_rate if isinstance(
            drop_path_rate,
            list) else [deepcopy(drop_path_rate) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            block = SwinBlock(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=feedforward_channels,
                window_size=window_size,
                shift=False if i % 2 == 0 else True,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rate[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                init_cfg=None)
            self.blocks.append(block)

        self.downsample = downsample
        self.with_cp = with_cp

    def forward(self, x, hw_shape):
        for block in self.blocks:
            if self.with_cp:
                x = checkpoint.checkpoint(block, x, hw_shape)
            else:
                x = block(x, hw_shape)

        if self.downsample:
            x_down, down_hw_shape = self.downsample(x, hw_shape)
            return x_down, down_hw_shape, x, hw_shape
        else:
            return x, hw_shape, x, hw_shape
@BACKBONES.register_module()
class SwinTransformer(BaseModule):
    """Swin Transformer.

    A PyTorch implementation of `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows`,
    https://arxiv.org/abs/2103.14030

    Inspiration from https://github.com/microsoft/Swin-Transformer

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): The num of input channels.
            Defaults: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): If add a norm layer for patch embed and patch
            merging. Default: True.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at
            output of backbone. Defaults: dict(type='LN').
        pretrain_style (str): Choose to use official or mmcls pretrain weights.
            Default: official.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 pretrain_style='official',
                 pretrained=None,
                 init_cfg=None,
                 with_cp=True,
                 return_stereo_feat=False,
                 output_missing_index_as_none=False,
                 frozen_stages=-1):
        super(SwinTransformer, self).__init__()

        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert pretrain_style in ['official', 'mmcls'], 'We only support load ' \
            'official ckpt and mmcls ckpt.'

        if isinstance(pretrained, str) or pretrained is None:
            warnings.warn('DeprecationWarning: "pretrained" is deprecated, '
                          'please use "init_cfg" instead')
        else:
            raise TypeError('pretrained must be a str or None')

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.pretrain_style = pretrain_style
        self.pretrained = pretrained
        self.init_cfg = init_cfg
        self.frozen_stages = frozen_stages

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            pad_to_patch_size=True,
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]    # stochastic depth decay rule

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=mlp_ratio * in_channels,
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[:depths[i]],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                init_cfg=None,
                with_cp=with_cp)
            self.stages.append(stage)

            dpr = dpr[depths[i]:]
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)
        self.output_missing_index_as_none = output_missing_index_as_none
        self._freeze_stages()
        self.return_stereo_feat = return_stereo_feat

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.use_abs_pos_embed:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.drop_after_pos.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.stages[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self):
        if self.pretrained is None:
            super().init_weights()
            if self.use_abs_pos_embed:
                trunc_normal_init(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            ckpt = _load_checkpoint(
                self.pretrained, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            if self.pretrain_style == 'official':
                state_dict = swin_convert(state_dict)

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}
            # if list(state_dict.keys())[0].startswith('backbone.'):
            #     state_dict = {k[9:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                else:
                    if L1 != L2:
                        S1 = int(L1**0.5)
                        S2 = int(L2**0.5)
                        table_pretrained_resized = resize(
                            table_pretrained.permute(1, 0).reshape(
                                1, nH1, S1, S1),
                            size=(S2, S2),
                            mode='bicubic')
                        state_dict[table_key] = table_pretrained_resized.view(
                            nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)

    def forward(self, x):
        x = self.patch_embed(x)

        hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
            if i == 0 and self.return_stereo_feat:
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(
                                   0, 3, 1, 2).contiguous()
                outs.append(out)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(out)
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(
                                   0, 3, 1, 2).contiguous()
                outs.append(out)
            elif self.output_missing_index_as_none:
                outs.append(None)
        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
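
A quick usage sketch (not part of the commit) for the Swin backbone above. It assumes the plugin is importable under the path shown in this diff; with_cp is disabled here only to avoid checkpointing warnings on inputs that do not require grad. With the defaults, a 224x224 image yields the four standard Swin feature maps:

import torch
from projects.mmdet3d_plugin.models.backbones.swin import SwinTransformer

model = SwinTransformer(embed_dims=96, depths=(2, 2, 6, 2),
                        num_heads=(3, 6, 12, 24), out_indices=(0, 1, 2, 3),
                        with_cp=False)
model.init_weights()
for o in model(torch.randn(1, 3, 224, 224)):
    print(o.shape)
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)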
projects/mmdet3d_plugin/models/dense_heads/__init__.py (new file, mode 100644)
from .bev_centerpoint_head import BEV_CenterHead, Centerness_Head
from .bev_occ_head import BEVOCCHead2D, BEVOCCHead3D, BEVOCCHead2D_V2

__all__ = ['Centerness_Head', 'BEV_CenterHead', 'BEVOCCHead2D', 'BEVOCCHead3D',
           'BEVOCCHead2D_V2']
projects/mmdet3d_plugin/models/dense_heads/bev_centerpoint_head.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn

from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
                          xywhr2xyxyr)
from ...core.post_processing import nms_bev
from mmdet3d.models import builder
from mmdet3d.models.utils import clip_sigmoid
from mmdet.core import build_bbox_coder, multi_apply, reduce_mean
from mmdet3d.models.builder import HEADS, build_loss


@HEADS.register_module(force=True)
class SeparateHead(BaseModule):
    """SeparateHead for CenterHead.

    Args:
        in_channels (int): Input channels for conv_layer.
        heads (dict): Conv information.
        head_conv (int, optional): Output channels.
            Default: 64.
        final_kernel (int, optional): Kernel size for the last conv layer.
            Default: 1.
        init_bias (float, optional): Initial bias. Default: -2.19.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias. Default: 'auto'.
    """

    def __init__(self,
                 in_channels,
                 heads,
                 head_conv=64,
                 final_kernel=1,
                 init_bias=-2.19,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(SeparateHead, self).__init__(init_cfg=init_cfg)
        self.heads = heads
        self.init_bias = init_bias
        for head in self.heads:
            # Output channels and number of conv layers for this head.
            classes, num_conv = self.heads[head]

            conv_layers = []
            c_in = in_channels
            for i in range(num_conv - 1):
                conv_layers.append(
                    ConvModule(
                        c_in,
                        head_conv,
                        kernel_size=final_kernel,
                        stride=1,
                        padding=final_kernel // 2,
                        bias=bias,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg))
                c_in = head_conv

            conv_layers.append(
                build_conv_layer(
                    conv_cfg,
                    head_conv,
                    classes,
                    kernel_size=final_kernel,
                    stride=1,
                    padding=final_kernel // 2,
                    bias=True))
            conv_layers = nn.Sequential(*conv_layers)

            self.__setattr__(head, conv_layers)

        if init_cfg is None:
            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

    def init_weights(self):
        """Initialize weights."""
        super().init_weights()
        for head in self.heads:
            if head == 'heatmap':
                self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for SepHead.

        Args:
            x (torch.Tensor): Input feature map with the shape of
                [B, 512, 128, 128].

        Returns:
            dict[str: torch.Tensor]: contains the following keys:
                -reg (torch.Tensor): 2D regression value with the
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of
                    [B, N, H, W].
        """
        ret_dict = dict()
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)

        return ret_dict
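

def _demo_separate_head():
    # Sketch (not part of the original commit): each entry of `heads` maps a
    # branch name to (out_channels, num_convs); forward returns one tensor per
    # branch. Assumes the mmdet3d environment this module imports from.
    head = SeparateHead(
        in_channels=64,
        heads=dict(reg=(2, 2), height=(1, 2), dim=(3, 2),
                   rot=(2, 2), vel=(2, 2), heatmap=(2, 2)))
    x = torch.randn(1, 64, 128, 128)
    out = head(x)
    print({k: tuple(v.shape) for k, v in out.items()})
    # e.g. {'reg': (1, 2, 128, 128), ..., 'heatmap': (1, 2, 128, 128)}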
@HEADS.register_module(force=True)
class DCNSeparateHead(BaseModule):
    r"""DCNSeparateHead for CenterHead.

    .. code-block:: none
            /-----> DCN for heatmap task -----> heatmap task.
    feature
            \-----> DCN for regression tasks -----> regression tasks

    Args:
        in_channels (int): Input channels for conv_layer.
        num_cls (int): Number of classes.
        heads (dict): Conv information.
        dcn_config (dict): Config of dcn layer.
        head_conv (int, optional): Output channels.
            Default: 64.
        final_kernel (int, optional): Kernel size for the last conv
            layer. Default: 1.
        init_bias (float, optional): Initial bias. Default: -2.19.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias. Default: 'auto'.
    """  # noqa: W605

    def __init__(self,
                 in_channels,
                 num_cls,
                 heads,
                 dcn_config,
                 head_conv=64,
                 final_kernel=1,
                 init_bias=-2.19,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
        if 'heatmap' in heads:
            heads.pop('heatmap')
        # feature adaptation with dcn
        # use separate features for classification / regression
        self.feature_adapt_cls = build_conv_layer(dcn_config)
        self.feature_adapt_reg = build_conv_layer(dcn_config)

        # heatmap prediction head
        cls_head = [
            ConvModule(
                in_channels,
                head_conv,
                kernel_size=3,
                padding=1,
                conv_cfg=conv_cfg,
                bias=bias,
                norm_cfg=norm_cfg),
            build_conv_layer(
                conv_cfg,
                head_conv,
                num_cls,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias)
        ]
        self.cls_head = nn.Sequential(*cls_head)
        self.init_bias = init_bias
        # other regression target
        self.task_head = SeparateHead(
            in_channels,
            heads,
            head_conv=head_conv,
            final_kernel=final_kernel,
            bias=bias)
        if init_cfg is None:
            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

    def init_weights(self):
        """Initialize weights."""
        super().init_weights()
        self.cls_head[-1].bias.data.fill_(self.init_bias)

    def forward(self, x):
        """Forward function for DCNSepHead.

        Args:
            x (torch.Tensor): Input feature map with the shape of
                [B, 512, 128, 128].

        Returns:
            dict[str: torch.Tensor]: contains the following keys:
                -reg (torch.Tensor): 2D regression value with the
                    shape of [B, 2, H, W].
                -height (torch.Tensor): Height value with the
                    shape of [B, 1, H, W].
                -dim (torch.Tensor): Size value with the shape
                    of [B, 3, H, W].
                -rot (torch.Tensor): Rotation value with the
                    shape of [B, 2, H, W].
                -vel (torch.Tensor): Velocity value with the
                    shape of [B, 2, H, W].
                -heatmap (torch.Tensor): Heatmap with the shape of
                    [B, N, H, W].
        """
        center_feat = self.feature_adapt_cls(x)
        reg_feat = self.feature_adapt_reg(x)

        cls_score = self.cls_head(center_feat)
        ret = self.task_head(reg_feat)
        ret['heatmap'] = cls_score

        return ret
@HEADS.register_module()
class BEV_CenterHead(BaseModule):
    """CenterHead for CenterPoint.

    Args:
        in_channels (list[int] | int, optional): Channels of the input
            feature map. Default: [128].
        tasks (list[dict], optional): Task information including class number
            and class names. Default: None.
        train_cfg (dict, optional): Train-time configs. Default: None.
        test_cfg (dict, optional): Test-time configs. Default: None.
        bbox_coder (dict, optional): Bbox coder configs. Default: None.
        common_heads (dict, optional): Conv information for common heads.
            Default: dict().
        loss_cls (dict, optional): Config of classification loss function.
            Default: dict(type='GaussianFocalLoss', reduction='mean').
        loss_bbox (dict, optional): Config of regression loss function.
            Default: dict(type='L1Loss', reduction='none').
        separate_head (dict, optional): Config of separate head. Default: dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3)
        share_conv_channel (int, optional): Output channels for share_conv
            layer. Default: 64.
        num_heatmap_convs (int, optional): Number of conv layers for heatmap
            conv layer. Default: 2.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias. Default: 'auto'.
    """

    def __init__(self,
                 in_channels=[128],
                 tasks=None,
                 train_cfg=None,
                 test_cfg=None,
                 bbox_coder=None,
                 common_heads=dict(),
                 loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
                 loss_bbox=dict(
                     type='L1Loss', reduction='none', loss_weight=0.25),
                 separate_head=dict(
                     type='SeparateHead', init_bias=-2.19, final_kernel=3),
                 share_conv_channel=64,
                 num_heatmap_convs=2,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 norm_bbox=True,
                 init_cfg=None,
                 task_specific=True):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(BEV_CenterHead, self).__init__(init_cfg=init_cfg)

        # Number of classes each task (SeparateHead) is responsible for.
        num_classes = [len(t['class_names']) for t in tasks]
        # Class names each task (SeparateHead) is responsible for.
        self.class_names = [t['class_names'] for t in tasks]
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.norm_bbox = norm_bbox

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_anchor_per_locs = [n for n in num_classes]
        self.fp16_enabled = False

        # a shared convolution
        self.shared_conv = ConvModule(
            in_channels,
            share_conv_channel,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=bias)

        # Build the corresponding head for each task.
        self.task_heads = nn.ModuleList()
        for num_cls in num_classes:
            # common_heads = dict(
            #     reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
            separate_head.update(
                in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
            self.task_heads.append(builder.build_head(separate_head))

        self.with_velocity = 'vel' in common_heads.keys()
        self.task_specific = task_specific

    def forward_single(self, x):
        """Forward function for CenterPoint.

        Args:
            x (torch.Tensor): Input feature map with the shape of
                [B, 512, 128, 128].

        Returns:
            list[dict]: Output results for tasks.
        """
        ret_dicts = []

        x = self.shared_conv(x)    # (B, C'=share_conv_channel, H, W)

        # Run each task head.
        for task in self.task_heads:
            ret_dicts.append(task(x))

        # ret_dicts: [dict0, dict1, ...], len = number of SeparateHeads
        # dict: {
        #     reg: (B, 2, H, W)
        #     height: (B, 1, H, W)
        #     dim: (B, 3, H, W)
        #     rot: (B, 2, H, W)
        #     vel: (B, 2, H, W)
        #     heatmap: (B, n_cls, H, W)
        # }
        return ret_dicts

    def forward(self, feats):
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.,
                features produced by FPN.

        Returns:
            results: Tuple(
                List[ret_dict_task0_level0, ...],    len = num_levels = 1
                List[ret_dict_task1_level0, ...],
                ...
            ), len = number of SeparateHeads, each predicting its assigned classes.
            ret_dict: {
                reg: (B, 2, H, W)
                height: (B, 1, H, W)
                dim: (B, 3, H, W)
                rot: (B, 2, H, W)
                vel: (B, 2, H, W)
                heatmap: (B, n_cls, H, W)
            }
        """
        return multi_apply(self.forward_single, feats)

    def _gather_feat(self, feat, ind, mask=None):
        """Gather feature map.

        Given feature map and index, return indexed feature map.

        Args:
            feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
            ind (torch.Tensor): Index of the ground truth boxes with the
                shape of [B, max_obj].
            mask (torch.Tensor, optional): Mask of the feature map with the
                shape of [B, max_obj]. Default: None.

        Returns:
            torch.Tensor: Feature map after gathering with the shape
                of [B, max_obj, 10].
        """
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def get_targets(self, gt_bboxes_3d, gt_labels_3d):
        """Generate targets.

        How each output is transformed:
            Each nested list is transposed so that all same-index elements in
            each sub-list (1, ..., N) become the new sub-lists.
                [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
                ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
            The new transposed nested list is converted into a list of N
            tensors generated by concatenating tensors in the new sub-lists.
                [ tensor0, tensor1, tensor2, ... ]

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.    # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.    # List[(N_gt0, ), (N_gt1, ), ...]

        Returns:
            tuple[list[torch.Tensor]]: (
                heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...], len = num of SeparateHeads
                anno_boxes:
                inds:
                masks:
            )
        """
        heatmaps, anno_boxes, inds, masks = multi_apply(
            self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
        # heatmaps:   Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...), len = batch_size
        # anno_boxes: Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...), len = batch_size
        # inds:       Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
        # masks:      Tuple(List[(max_objs, ), (max_objs, ), ...], ...)

        # Transpose heatmaps
        # List[List[(N_cls0, H, W), ...], List[(N_cls1, H, W), ...], ...], len = num of SeparateHeads
        heatmaps = list(map(list, zip(*heatmaps)))
        heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
        # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...], len = num of SeparateHeads

        # Transpose anno_boxes
        anno_boxes = list(map(list, zip(*anno_boxes)))
        anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
        # List[(B, max_objs, 10), (B, max_objs, 10), ...], len = num of SeparateHeads

        # Transpose inds
        inds = list(map(list, zip(*inds)))
        inds = [torch.stack(inds_) for inds_ in inds]
        # List[(B, max_objs), (B, max_objs), ...], len = num of SeparateHeads

        # Transpose masks
        masks = list(map(list, zip(*masks)))
        masks = [torch.stack(masks_) for masks_ in masks]
        # List[(B, max_objs), (B, max_objs), ...], len = num of SeparateHeads
        return heatmaps, anno_boxes, inds, masks
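
    # Illustration (not part of the original commit) of the transpose above,
    # with assumed placeholder values: per-sample outputs [[a0, a1], [b0, b1]]
    # regroup per task via
    #     list(map(list, zip(*[[a0, a1], [b0, b1]]))) == [[a0, b0], [a1, b1]],
    # and torch.stack then batches each per-task group into a single tensor.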
def
get_targets_single
(
self
,
gt_bboxes_3d
,
gt_labels_3d
):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. # (N_gt, 7/9)
gt_labels_3d (torch.Tensor): Labels of boxes. # (N_gt, )
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- heatmaps: list[torch.Tensor]: Heatmap scores. # List[(N_cls0, H, W), (N_cls1, H, W), ...]
len = num of tasks
- anno_boxes: list[torch.Tensor]: Ground truth boxes. # List[(max_objs, 10), (max_objs, 10), ...]
- inds: list[torch.Tensor]: Indexes indicating the position
of the valid boxes. # List[(max_objs, ), (max_objs, ), ...]
- masks: list[torch.Tensor]: Masks indicating which boxes
are valid. # List[(max_objs, ), (max_objs, ), ...]
"""
device
=
gt_labels_3d
.
device
gt_bboxes_3d
=
torch
.
cat
(
(
gt_bboxes_3d
.
gravity_center
,
gt_bboxes_3d
.
tensor
[:,
3
:]),
dim
=
1
).
to
(
device
)
# (N_gt, 7/9)
max_objs
=
self
.
train_cfg
[
'max_objs'
]
*
self
.
train_cfg
[
'dense_reg'
]
grid_size
=
torch
.
tensor
(
self
.
train_cfg
[
'grid_size'
])
# (Dx, Dy, Dz)
pc_range
=
torch
.
tensor
(
            self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']   # (W, H)

        # reorganize the gt_dict by tasks
        task_masks = []
        flag = 0
        for class_name in self.class_names:
            # class_name: the class names that each task (SeparateHead) is responsible for.
            task_masks.append([
                torch.where(gt_labels_3d == class_name.index(i) + flag)
                for i in class_name
            ])
            flag += len(class_name)
        # task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
        # task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class

        task_boxes = []
        task_classes = []
        flag2 = 0
        for idx, mask in enumerate(task_masks):
            # mask: the mask of each task (SeparateHead); every task detects a
            # different group of classes.
            # List[((N_gt0, ), ), ((N_gt1, ), ), ...], N_gt_task = N_gt0 + N_gt1 + ...,
            # i.e. the number of gt_boxes the current task is responsible for.
            task_box = []
            task_class = []
            for m in mask:
                task_box.append(gt_bboxes_3d[m])
                # 0 is background for each task, so we need to add 1 here.
                task_class.append(gt_labels_3d[m] + 1 - flag2)
            task_boxes.append(torch.cat(task_box, axis=0).to(device))
            task_classes.append(torch.cat(task_class).long().to(device))
            flag2 += len(mask)
        # Record the gt_boxes and gt_classes each task is responsible for:
        # task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
        # task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]

        draw_gaussian = draw_heatmap_gaussian
        heatmaps, anno_boxes, inds, masks = [], [], [], []
        for idx, task_head in enumerate(self.task_heads):
            heatmap = gt_bboxes_3d.new_zeros(
                (len(self.class_names[idx]),
                 feature_map_size[1],
                 feature_map_size[0]))
            # (N_cls, H, W), N_cls is the number of classes the current task_head detects.
            if self.with_velocity:
                anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32)    # (max_objs, 10)
            else:
                anno_box = gt_bboxes_3d.new_zeros((max_objs, 8), dtype=torch.float32)
            ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64)     # (max_objs, )
            mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8)    # (max_objs, )

            num_objs = min(task_boxes[idx].shape[0], max_objs)   # objects the current task_head detects.
            for k in range(num_objs):
                cls_id = task_classes[idx][k] - 1   # cls_id of the current object, relative to the task group.
                width = task_boxes[idx][k][3]    # dx
                length = task_boxes[idx][k][4]   # dy
                # width and length of the current object on the feature map.
                width = width / voxel_size[0] / self.train_cfg['out_size_factor']
                length = length / voxel_size[1] / self.train_cfg['out_size_factor']

                if width > 0 and length > 0:
                    # compute the gaussian radius.
                    radius = gaussian_radius(
                        (length, width),
                        min_overlap=self.train_cfg['gaussian_overlap'])
                    radius = max(self.train_cfg['min_radius'], int(radius))

                    # be really careful for the coordinate system of
                    # your box annotation.
                    x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][1], \
                        task_boxes[idx][k][2]   # center coordinates of the current object.

                    # compute the position of the gt_box center on the feature map.
                    coor_x = (x - pc_range[0]) / voxel_size[0] / self.train_cfg['out_size_factor']
                    coor_y = (y - pc_range[1]) / voxel_size[1] / self.train_cfg['out_size_factor']

                    center = torch.tensor([coor_x, coor_y],
                                          dtype=torch.float32,
                                          device=device)
                    center_int = center.to(torch.int32)

                    # throw out not in range objects to avoid out of array
                    # area when creating the heatmap
                    if not (0 <= center_int[0] < feature_map_size[0]
                            and 0 <= center_int[1] < feature_map_size[1]):
                        continue

                    # draw the heatmap according to the object's center position
                    # on the feature map and the gaussian radius.
                    draw_gaussian(heatmap[cls_id], center_int, radius)

                    new_idx = k
                    x, y = center_int[0], center_int[1]
                    assert (y * feature_map_size[0] + x <
                            feature_map_size[0] * feature_map_size[1])

                    # record the position of the positive sample on the feature map.
                    ind[new_idx] = y * feature_map_size[0] + x
                    mask[new_idx] = 1

                    # TODO: support other outdoor dataset
                    rot = task_boxes[idx][k][6]
                    box_dim = task_boxes[idx][k][3:6]
                    if self.norm_bbox:
                        box_dim = box_dim.log()
                    if self.with_velocity:
                        vx, vy = task_boxes[idx][k][7:]
                        anno_box[new_idx] = torch.cat([
                            center - torch.tensor([x, y], device=device),   # tx, ty
                            z.unsqueeze(0),
                            box_dim,                        # z, log(dx), log(dy), log(dz)
                            torch.sin(rot).unsqueeze(0),    # sin(rot)
                            torch.cos(rot).unsqueeze(0),    # cos(rot)
                            vx.unsqueeze(0),                # vx
                            vy.unsqueeze(0)                 # vy
                        ])
                        # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
                    else:
                        anno_box[new_idx] = torch.cat([
                            center - torch.tensor([x, y], device=device),
                            z.unsqueeze(0),
                            box_dim,
                            torch.sin(rot).unsqueeze(0),
                            torch.cos(rot).unsqueeze(0)
                        ])

            heatmaps.append(heatmap)       # append (N_cls, H, W)
            anno_boxes.append(anno_box)    # append (max_objs, 10)
            masks.append(mask)             # append (max_objs, )
            inds.append(ind)               # append (max_objs, )
        return heatmaps, anno_boxes, inds, masks
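
    # --- Editor's sketch (not part of the original commit) -------------------
    # A minimal, hedged illustration of the target assignment above, using
    # the same mmdet3d utilities this head relies on. All numbers (voxel size
    # 0.1, out_size_factor 8, pc_range origin -51.2, the box itself) are
    # hypothetical toy values, not values from this repo's configs.
    @staticmethod
    def _demo_heatmap_target():
        import torch
        from mmdet3d.core import draw_heatmap_gaussian, gaussian_radius

        feature_map_size = (128, 128)                        # (W, H)
        heatmap = torch.zeros((feature_map_size[1], feature_map_size[0]))
        width = 1.9 / 0.1 / 8                                # box extent on the feature map
        length = 4.2 / 0.1 / 8
        radius = max(2, int(gaussian_radius((length, width), min_overlap=0.1)))
        coor_x = (10.0 - (-51.2)) / 0.1 / 8                  # center x on the feature map
        coor_y = (-5.0 - (-51.2)) / 0.1 / 8                  # center y on the feature map
        center_int = torch.tensor([coor_x, coor_y]).to(torch.int32)
        draw_heatmap_gaussian(heatmap, center_int, radius)
        # flattened position that would be stored in `ind` for this positive sample:
        return int(center_int[1]) * feature_map_size[0] + int(center_int[0])
    # -------------------------------------------------------------------------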
    def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
        """Loss function for CenterHead.
        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.     # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.     # List[(N_gt0, ), (N_gt1, ), ...]
            preds_dicts (dict): Tuple(
                List[ret_dict_task0_level0, ...],   len = num_levels = 1
                List[ret_dict_task1_level0, ...],
                ...
            ), len = number of SeparateHeads, each predicting its assigned classes.
            ret_dict: {
                reg: (B, 2, H, W)
                height: (B, 1, H, W)
                dim: (B, 3, H, W)
                rot: (B, 2, H, W)
                vel: (B, 2, H, W)
                heatmap: (B, n_cls, H, W)
            }
        Returns:
            dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
        # heatmaps:   List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...]  len = num of SeparateHead
        # anno_boxes: List[(B, max_objs, 10), (B, max_objs, 10), ...]  len = num of SeparateHead
        # inds:       List[(B, max_objs), (B, max_objs), ...]          len = num of SeparateHead
        # masks:      List[(B, max_objs), (B, max_objs), ...]          len = num of SeparateHead

        loss_dict = dict()
        if not self.task_specific:
            loss_dict['loss'] = 0

        for task_id, preds_dict in enumerate(preds_dicts):
            # task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels; for CenterPoint, len = 1
            # dict: {
            #   reg: (B, 2, H, W)
            #   height: (B, 1, H, W)
            #   dim: (B, 3, H, W)
            #   rot: (B, 2, H, W)
            #   vel: (B, 2, H, W)
            #   heatmap: (B, n_cls, H, W)
            # }
            # heatmap focal loss
            preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
            num_pos = heatmaps[task_id].eq(1).float().sum().item()
            cls_avg_factor = torch.clamp(
                reduce_mean(heatmaps[task_id].new_tensor(num_pos)), min=1).item()
            loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],   # (B, cur_N_cls, H, W)
                heatmaps[task_id],          # (B, cur_N_cls, H, W)
                avg_factor=cls_avg_factor
            )

            # (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
            target_box = anno_boxes[task_id]
            # reconstruct the anno_box from multiple reg heads
            preds_dict[0]['anno_box'] = torch.cat(
                (
                    preds_dict[0]['reg'],
                    preds_dict[0]['height'],
                    preds_dict[0]['dim'],
                    preds_dict[0]['rot'],
                    preds_dict[0]['vel'],
                ),
                dim=1,
            )   # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)

            # Regression loss for dimension, offset, height, rotation
            num = masks[task_id].float().sum()      # number of positive samples
            ind = inds[task_id]     # (B, max_objs)
            pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()   # (B, H, W, 10)
            pred = pred.view(pred.size(0), -1, pred.size(3))    # (B, H*W, 10)
            pred = self._gather_feat(pred, ind)     # (B, max_objs, 10)

            # (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
            mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
            num = torch.clamp(reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
            isnotnan = (~torch.isnan(target_box)).float()
            mask *= isnotnan    # supervise only the reg predictions selected by the mask.
            code_weights = self.train_cfg['code_weights']
            bbox_weights = mask * mask.new_tensor(code_weights)
            # on top of the mask, weight the individual box attributes.  (B, max_objs, 10)

            if self.task_specific:
                name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
                clip_index = [0, 2, 3, 6, 8, 10]
                for reg_task_id in range(len(name_list)):
                    pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]                  # (B, max_objs, K)
                    target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]      # (B, max_objs, K)
                    bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]  # (B, max_objs, K)
                    loss_bbox_tmp = self.loss_bbox(
                        pred_tmp, target_box_tmp, bbox_weights_tmp,
                        avg_factor=(num + 1e-4))
                    loss_dict[f'task{task_id}.loss_%s' % (name_list[reg_task_id])] = loss_bbox_tmp
                loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
            else:
                loss_bbox = self.loss_bbox(pred, target_box, bbox_weights, avg_factor=num)
                loss_dict['loss'] += loss_bbox
                loss_dict['loss'] += loss_heatmap
        return loss_dict
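
    # --- Editor's sketch (not part of the original commit) -------------------
    # How the task-specific branch slices the 10-dim box code into the
    # ('xy', 'z', 'whl', 'yaw', 'vel') groups via `clip_index`. Pure-tensor
    # toy example; the shapes are made up.
    @staticmethod
    def _demo_code_slicing():
        import torch
        pred = torch.randn(2, 500, 10)              # (B, max_objs, 10)
        name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
        clip_index = [0, 2, 3, 6, 8, 10]
        slices = {name: pred[..., clip_index[i]:clip_index[i + 1]]
                  for i, name in enumerate(name_list)}
        # xy: (.., 2), z: (.., 1), whl: (.., 3), yaw: (.., 2) as (sin, cos), vel: (.., 2)
        return {k: tuple(v.shape) for k, v in slices.items()}
    # -------------------------------------------------------------------------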
    def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
        """Generate bboxes from bbox head predictions.
        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.
                Tuple(
                    List[ret_dict_task0_level0, ...],   len = num_levels = 1
                    List[ret_dict_task1_level0, ...],
                    ...
                ), len = number of SeparateHeads, each predicting its assigned classes.
                ret_dict: {
                    reg: (B, 2, H, W)
                    height: (B, 1, H, W)
                    dim: (B, 3, H, W)
                    rot: (B, 2, H, W)
                    vel: (B, 2, H, W)
                    heatmap: (B, n_cls, H, W)
                }
            img_metas (list[dict]): Point cloud and image's meta info.
        Returns:
            list[dict]: Decoded bbox, scores and labels after nms.
            ret_list: List[p_list0, p_list1, ...]
            p_list: List[(N, 9), (N, ), (N, )]
        """
        rets = []
        for task_id, preds_dict in enumerate(preds_dicts):
            # task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels; for CenterPoint, len = 1
            # dict: {
            #   reg: (B, 2, H, W)
            #   height: (B, 1, H, W)
            #   dim: (B, 3, H, W)
            #   rot: (B, 2, H, W)
            #   vel: (B, 2, H, W)
            #   heatmap: (B, n_cls, H, W)
            # }
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_heatmap = preds_dict[0]['heatmap'].sigmoid()      # (B, n_cls, H, W)
            batch_reg = preds_dict[0]['reg']        # (B, 2, H, W)
            batch_hei = preds_dict[0]['height']     # (B, 1, H, W)

            if self.norm_bbox:
                batch_dim = torch.exp(preds_dict[0]['dim'])     # (B, 3, H, W)
            else:
                batch_dim = preds_dict[0]['dim']

            batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)    # (B, 1, H, W)
            batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)    # (B, 1, H, W)

            if 'vel' in preds_dict[0]:
                batch_vel = preds_dict[0]['vel']    # (B, 2, H, W)
            else:
                batch_vel = None

            temp = self.bbox_coder.decode(
                batch_heatmap, batch_rots, batch_rotc, batch_hei,
                batch_dim, batch_vel, reg=batch_reg, task_id=task_id)
            # temp: List[p_dict0, p_dict1, ...] len=bs
            # p_dict = {
            #   'bboxes': boxes3d,  # (K', 9)
            #   'scores': scores,   # (K', )
            #   'labels': labels    # (K', )
            # }
            batch_reg_preds = [box['bboxes'] for box in temp]   # List[(K0, 9), (K1, 9), ...] len = bs
            batch_cls_preds = [box['scores'] for box in temp]   # List[(K0, ), (K1, ), ...] len = bs
            batch_cls_labels = [box['labels'] for box in temp]  # List[(K0, ), (K1, ), ...] len = bs

            nms_type = self.test_cfg.get('nms_type')
            if isinstance(nms_type, list):
                nms_type = nms_type[task_id]
            if nms_type == 'circle':
                ret_task = []
                for i in range(batch_size):
                    boxes3d = temp[i]['bboxes']
                    scores = temp[i]['scores']
                    labels = temp[i]['labels']
                    centers = boxes3d[:, [0, 1]]
                    boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
                    keep = torch.tensor(
                        circle_nms(
                            boxes.detach().cpu().numpy(),
                            self.test_cfg['min_radius'][task_id],
                            post_max_size=self.test_cfg['post_max_size']),
                        dtype=torch.long,
                        device=boxes.device)
                    boxes3d = boxes3d[keep]
                    scores = scores[keep]
                    labels = labels[keep]
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                    ret_task.append(ret)
                rets.append(ret_task)
            else:
                rets.append(
                    self.get_task_detections(batch_cls_preds, batch_reg_preds,
                                             batch_cls_labels, img_metas,
                                             task_id))

        # rets: List[ret_task0, ret_task1, ...], len = num_tasks
        # ret_task: List[p_dict0, p_dict1, ...], len = batch_size
        # p_dict: dict{
        #   bboxes: (K', 9)
        #   scores: (K', )
        #   labels: (K', )
        # }
        # Merge branches results
        num_samples = len(rets[0])      # bs
        ret_list = []
        # iterate over the batch, then merge the predictions of all tasks.
        for i in range(num_samples):
            for k in rets[0][i].keys():
                if k == 'bboxes':
                    bboxes = torch.cat([ret[i][k] for ret in rets])
                    # bboxes can simply be concatenated.
                    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
                    bboxes = img_metas[i]['box_type_3d'](bboxes, self.bbox_coder.code_size)
                elif k == 'scores':
                    scores = torch.cat([ret[i][k] for ret in rets])
                    # scores can simply be concatenated.
                elif k == 'labels':
                    flag = 0
                    for j, num_class in enumerate(self.num_classes):
                        # labels must be shifted, since the predicted label is
                        # relative to the task group.
                        rets[j][i][k] += flag
                        flag += num_class
                    labels = torch.cat([ret[i][k].int() for ret in rets])
            ret_list.append([bboxes, scores, labels])
        # ret_list: List[p_list0, p_list1, ...]
        # p_list: List[(N, 9), (N, ), (N, )]
        return ret_list
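
    # --- Editor's sketch (not part of the original commit) -------------------
    # Why labels must be shifted when merging task outputs: each head predicts
    # ids relative to its own class group. Toy values; the real class split
    # comes from the `tasks` config.
    @staticmethod
    def _demo_label_merge():
        import torch
        num_classes = [1, 2, 2]                     # classes per task (hypothetical)
        task_labels = [torch.tensor([0]), torch.tensor([0, 1]), torch.tensor([1])]
        flag = 0
        merged = []
        for j, num_class in enumerate(num_classes):
            merged.append(task_labels[j] + flag)    # shift into the global id space
            flag += num_class
        return torch.cat(merged)                    # tensor([0, 1, 2, 4])
    # -------------------------------------------------------------------------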
    def get_task_detections(self, batch_cls_preds, batch_reg_preds,
                            batch_cls_labels, img_metas, task_id):
        """Rotate nms for each task.
        Args:
            batch_cls_preds (list[torch.Tensor]): Prediction score with the
                shape of [N].       # List[(K0, ), (K1, ), ...] len = bs
            batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
                shape of [N, 9].    # List[(K0, 9), (K1, 9), ...] len = bs
            batch_cls_labels (list[torch.Tensor]): Prediction label with the
                shape of [N].       # List[(K0, ), (K1, ), ...] len = bs
            img_metas (list[dict]): Meta information of each sample.
        Returns:
            list[dict[str: torch.Tensor]]: contains the following keys:
                -bboxes (torch.Tensor): Prediction bboxes after nms with the
                    shape of [N, 9].
                -scores (torch.Tensor): Prediction scores after nms with the
                    shape of [N].
                -labels (torch.Tensor): Prediction labels after nms with the
                    shape of [N].
            List[p_dict0, p_dict1, ...] len = batch_size
            p_dict: dict{
                bboxes: (K', 9)
                scores: (K', )
                labels: (K', )
            }
        """
        predictions_dicts = []
        # iterate over the topK predictions of each sample in the batch.
        for i, (box_preds, cls_preds, cls_labels) in enumerate(
                zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
            # box_preds: (K, 9)
            # cls_preds: (K, )
            # cls_labels: (K, )
            default_val = [1.0 for _ in range(len(self.task_heads))]
            factor = self.test_cfg.get('nms_rescale_factor', default_val)[task_id]
            if isinstance(factor, list):
                # List[float, float, ..] len = number of classes the current task predicts.
                # rescale box_preds by the corresponding factor; typically small
                # objects are enlarged and large objects are shrunk.
                for cid in range(len(factor)):
                    box_preds[cls_labels == cid, 3:6] = \
                        box_preds[cls_labels == cid, 3:6] * factor[cid]
            else:
                box_preds[:, 3:6] = box_preds[:, 3:6] * factor

            # Apply NMS in birdeye view
            top_labels = cls_labels.long()      # (K, )
            top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
                else cls_preds      # (K, )

            if top_scores.shape[0] != 0:
                boxes_for_nms = img_metas[i]['box_type_3d'](
                    box_preds[:, :], self.bbox_coder.code_size).bev
                # (K, 5) (x, y, dx, dy, yaw)
                # the nms in 3d detection just remove overlap boxes.
                if isinstance(self.test_cfg['nms_thr'], list):
                    nms_thresh = self.test_cfg['nms_thr'][task_id]
                else:
                    nms_thresh = self.test_cfg['nms_thr']
                selected = nms_bev(
                    boxes_for_nms,
                    top_scores,
                    thresh=nms_thresh,
                    pre_max_size=self.test_cfg['pre_max_size'],
                    post_max_size=self.test_cfg['post_max_size'],
                    xyxyr2xywhr=False,
                )
            else:
                selected = []

            # after NMS, rescale the boxes back to their original size.
            if isinstance(factor, list):
                for cid in range(len(factor)):
                    box_preds[top_labels == cid, 3:6] = \
                        box_preds[top_labels == cid, 3:6] / factor[cid]
            else:
                box_preds[:, 3:6] = box_preds[:, 3:6] / factor

            # if selected is not None:
            selected_boxes = box_preds[selected]        # (K', 9)
            selected_labels = top_labels[selected]      # (K', )
            selected_scores = top_scores[selected]      # (K', )

            # finally generate predictions.
            if selected_boxes.shape[0] != 0:
                predictions_dict = dict(
                    bboxes=selected_boxes,
                    scores=selected_scores,
                    labels=selected_labels)
            else:
                dtype = batch_reg_preds[0].dtype
                device = batch_reg_preds[0].device
                predictions_dict = dict(
                    bboxes=torch.zeros([0, self.bbox_coder.code_size],
                                       dtype=dtype, device=device),
                    scores=torch.zeros([0], dtype=dtype, device=device),
                    labels=torch.zeros([0], dtype=top_labels.dtype, device=device))
            predictions_dicts.append(predictions_dict)
        return predictions_dicts
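
# --- Editor's sketch (not part of the original commit) -----------------------
# The `nms_rescale_factor` trick used in `get_task_detections` above: box
# dimensions are scaled before BEV NMS (typically enlarging small classes so
# near-duplicates overlap enough to be suppressed) and scaled back afterwards.
# The factor and box below are hypothetical.
def _demo_nms_rescale():
    import torch
    box_preds = torch.tensor([[0., 0., 0., 0.5, 0.5, 1.0, 0., 0., 0.]])  # (1, 9)
    factor = 2.0
    box_preds[:, 3:6] = box_preds[:, 3:6] * factor   # enlarge dims before NMS
    # ... BEV NMS would run here on the enlarged boxes ...
    box_preds[:, 3:6] = box_preds[:, 3:6] / factor   # restore the original dims
    return box_preds
# ------------------------------------------------------------------------------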
@HEADS.register_module()
class Centerness_Head(BaseModule):
    """CenterHead for CenterPoint.
    Args:
        in_channels (list[int] | int, optional): Channels of the input
            feature map. Default: [128].
        tasks (list[dict], optional): Task information including class number
            and class names. Default: None.
        train_cfg (dict, optional): Train-time configs. Default: None.
        test_cfg (dict, optional): Test-time configs. Default: None.
        bbox_coder (dict, optional): Bbox coder configs. Default: None.
        common_heads (dict, optional): Conv information for common heads.
            Default: dict().
        loss_cls (dict, optional): Config of classification loss function.
            Default: dict(type='GaussianFocalLoss', reduction='mean').
        loss_bbox (dict, optional): Config of regression loss function.
            Default: dict(type='L1Loss', reduction='none').
        separate_head (dict, optional): Config of separate head. Default: dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3)
        share_conv_channel (int, optional): Output channels for share_conv
            layer. Default: 64.
        num_heatmap_convs (int, optional): Number of conv layers for heatmap
            conv layer. Default: 2.
        conv_cfg (dict, optional): Config of conv layer.
            Default: dict(type='Conv2d')
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='BN2d').
        bias (str, optional): Type of bias. Default: 'auto'.
    """

    def __init__(self,
                 in_channels=[128],
                 tasks=None,
                 train_cfg=None,
                 test_cfg=None,
                 bbox_coder=None,
                 common_heads=dict(),
                 loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
                 loss_bbox=dict(type='L1Loss', reduction='none', loss_weight=0.25),
                 separate_head=dict(
                     type='SeparateHead', init_bias=-2.19, final_kernel=3),
                 share_conv_channel=64,
                 num_heatmap_convs=2,
                 conv_cfg=dict(type='Conv2d'),
                 norm_cfg=dict(type='BN2d'),
                 bias='auto',
                 norm_bbox=True,
                 init_cfg=None,
                 task_specific=True,
                 task_specific_weight=[1, 1, 1, 1, 1]):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(Centerness_Head, self).__init__(init_cfg=init_cfg)

        num_classes = [len(t['class_names']) for t in tasks]
        # number of classes each task (SeparateHead) is responsible for.
        self.class_names = [t['class_names'] for t in tasks]
        # class names each task (SeparateHead) is responsible for.
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.norm_bbox = norm_bbox

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_anchor_per_locs = [n for n in num_classes]
        self.fp16_enabled = False

        # a shared convolution
        self.shared_conv = ConvModule(
            in_channels,
            share_conv_channel,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=bias)

        # build a corresponding head for each task.
        self.task_heads = nn.ModuleList()
        for num_cls in num_classes:
            # common_heads = dict(
            #     reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
            separate_head.update(
                in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
            self.task_heads.append(builder.build_head(separate_head))

        self.with_velocity = 'vel' in common_heads.keys()
        self.task_specific = task_specific
        self.task_specific_weight = task_specific_weight
        # e.g. [1, 1, 0, 0, 0] to weight the 'xy', 'z', 'whl', 'yaw', 'vel' terms.
    def forward_single(self, x):
        """Forward function for CenterPoint.
        Args:
            x (torch.Tensor): Input feature map with the shape of
                [B, 512, 128, 128].
        Returns:
            list[dict]: Output results for tasks.
        """
        ret_dicts = []
        x = self.shared_conv(x)     # (B, C'=share_conv_channel, H, W)
        # run each task_head.
        for task in self.task_heads:
            ret_dicts.append(task(x))
        # ret_dicts: [dict0, dict1, ...] len = number of SeparateHeads
        # dict: {
        #   reg: (B, 2, H, W)
        #   height: (B, 1, H, W)
        #   dim: (B, 3, H, W)
        #   rot: (B, 2, H, W)
        #   vel: (B, 2, H, W)
        #   heatmap: (B, n_cls, H, W)
        # }
        return ret_dicts

    def forward(self, feats):
        """Forward pass.
        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.,
                features produced by FPN.
        Returns:
            results: Tuple(
                List[ret_dict_task0_level0, ...],   len = num_levels = 1
                List[ret_dict_task1_level0, ...],
                ...
            ), len = number of SeparateHeads, each predicting its assigned classes.
            ret_dict: {
                reg: (B, 2, H, W)
                height: (B, 1, H, W)
                dim: (B, 3, H, W)
                rot: (B, 2, H, W)
                vel: (B, 2, H, W)
                heatmap: (B, n_cls, H, W)
            }
        """
        return multi_apply(self.forward_single, feats)
    def _gather_feat(self, feat, ind, mask=None):
        """Gather feature map.
        Given feature map and index, return indexed feature map.
        Args:
            feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
            ind (torch.Tensor): Index of the ground truth boxes with the
                shape of [B, max_obj].
            mask (torch.Tensor, optional): Mask of the feature map with the
                shape of [B, max_obj]. Default: None.
        Returns:
            torch.Tensor: Feature map after gathering with the shape
                of [B, max_obj, 10].
        """
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat
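
    # --- Editor's sketch (not part of the original commit) -------------------
    # What `_gather_feat` does, on tiny tensors: pick one C-dim vector per
    # object index out of a flattened (B, H*W, C) map. All values are toy.
    @staticmethod
    def _demo_gather_feat():
        import torch
        feat = torch.arange(2 * 6 * 3, dtype=torch.float32).view(2, 6, 3)  # (B=2, H*W=6, C=3)
        ind = torch.tensor([[0, 5], [2, 2]])                               # (B, max_obj=2)
        ind = ind.unsqueeze(2).expand(2, 2, 3)
        picked = feat.gather(1, ind)                                       # (B, max_obj, C)
        return picked    # batch 0 -> rows 0 and 5; batch 1 -> row 2 twice
    # -------------------------------------------------------------------------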
    def get_targets(self, gt_bboxes_3d, gt_labels_3d):
        """Generate targets.
        How each output is transformed:
            Each nested list is transposed so that all same-index elements in
            each sub-list (1, ..., N) become the new sub-lists.
                [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
                ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
            The new transposed nested list is converted into a list of N
            tensors generated by concatenating tensors in the new sub-lists.
                [ tensor0, tensor1, tensor2, ... ]
        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.     # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.     # List[(N_gt0, ), (N_gt1, ), ...]
        Returns:
            tuple[list[torch.Tensor]]: (
                heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...]  len = num of SeparateHead
                anno_boxes:
                inds:
                masks:
            )
        """
        heatmaps, anno_boxes, inds, masks = multi_apply(
            self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
        # heatmaps:   Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...)  len = batch_size
        # anno_boxes: Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...)  len = batch_size
        # inds:       Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
        # masks:      Tuple(List[(max_objs, ), (max_objs, ), ...], ...)

        # Transpose heatmaps
        # List[List[(N_cls0, H, W), (N_cls0, H, W), ...], List[(N_cls1, H, W), (N_cls1, H, W), ...], ...]  len = num of SeparateHead
        heatmaps = list(map(list, zip(*heatmaps)))
        heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
        # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...]  len = num of SeparateHead

        # Transpose anno_boxes
        anno_boxes = list(map(list, zip(*anno_boxes)))
        anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
        # List[(B, max_objs, 10), (B, max_objs, 10), ...]  len = num of SeparateHead

        # Transpose inds
        inds = list(map(list, zip(*inds)))
        inds = [torch.stack(inds_) for inds_ in inds]
        # List[(B, max_objs), (B, max_objs), ...]  len = num of SeparateHead

        # Transpose masks
        masks = list(map(list, zip(*masks)))
        masks = [torch.stack(masks_) for masks_ in masks]
        # List[(B, max_objs), (B, max_objs), ...]  len = num of SeparateHead
        return heatmaps, anno_boxes, inds, masks
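
    # --- Editor's sketch (not part of the original commit) -------------------
    # The per-sample -> per-task transpose used in `get_targets`, shown on
    # plain lists: sample-major nesting becomes task-major nesting.
    @staticmethod
    def _demo_transpose():
        per_sample = [['a0', 'b0'], ['a1', 'b1'], ['a2', 'b2']]   # len = batch size
        per_task = list(map(list, zip(*per_sample)))              # len = num tasks
        return per_task    # [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
    # -------------------------------------------------------------------------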
    def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
        """Generate training targets for a single sample.
        Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes.   # (N_gt, 7/9)
            gt_labels_3d (torch.Tensor): Labels of boxes.   # (N_gt, )
        Returns:
            tuple[list[torch.Tensor]]: Tuple of target including
                the following results in order.
                - heatmaps: list[torch.Tensor]: Heatmap scores.  # List[(N_cls0, H, W), (N_cls1, H, W), ...]
                    len = num of tasks
                - anno_boxes: list[torch.Tensor]: Ground truth boxes.  # List[(max_objs, 10), (max_objs, 10), ...]
                - inds: list[torch.Tensor]: Indexes indicating the position
                    of the valid boxes.  # List[(max_objs, ), (max_objs, ), ...]
                - masks: list[torch.Tensor]: Masks indicating which boxes
                    are valid.  # List[(max_objs, ), (max_objs, ), ...]
        """
        device = gt_labels_3d.device
        gt_bboxes_3d = torch.cat(
            (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
            dim=1).to(device)   # (N_gt, 7/9)
        max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
        grid_size = torch.tensor(self.train_cfg['grid_size'])   # (Dx, Dy, Dz)
        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']   # (W, H)

        # reorganize the gt_dict by tasks
        task_masks = []
        flag = 0
        for class_name in self.class_names:
            # class_name: the class names that each task (SeparateHead) is responsible for.
            task_masks.append([
                torch.where(gt_labels_3d == class_name.index(i) + flag)
                for i in class_name
            ])
            flag += len(class_name)
        # task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
        # task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class

        task_boxes = []
        task_classes = []
        flag2 = 0
        for idx, mask in enumerate(task_masks):
            # mask: the mask of each task (SeparateHead); every task detects a
            # different group of classes.
            # List[((N_gt0, ), ), ((N_gt1, ), ), ...], N_gt_task = N_gt0 + N_gt1 + ...,
            # i.e. the number of gt_boxes the current task is responsible for.
            task_box = []
            task_class = []
            for m in mask:
                task_box.append(gt_bboxes_3d[m])
                # 0 is background for each task, so we need to add 1 here.
                task_class.append(gt_labels_3d[m] + 1 - flag2)
            task_boxes.append(torch.cat(task_box, axis=0).to(device))
            task_classes.append(torch.cat(task_class).long().to(device))
            flag2 += len(mask)
        # Record the gt_boxes and gt_classes each task is responsible for:
        # task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
        # task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]

        draw_gaussian = draw_heatmap_gaussian
        heatmaps, anno_boxes, inds, masks = [], [], [], []
        for idx, task_head in enumerate(self.task_heads):
            heatmap = gt_bboxes_3d.new_zeros(
                (len(self.class_names[idx]),
                 feature_map_size[1],
                 feature_map_size[0]))
            # (N_cls, H, W), N_cls is the number of classes the current task_head detects.
            if self.with_velocity:
                anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32)    # (max_objs, 10)
            else:
                anno_box = gt_bboxes_3d.new_zeros((max_objs, 8), dtype=torch.float32)
            ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64)     # (max_objs, )
            mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8)    # (max_objs, )

            num_objs = min(task_boxes[idx].shape[0], max_objs)   # objects the current task_head detects.
            for k in range(num_objs):
                cls_id = task_classes[idx][k] - 1   # cls_id of the current object, relative to the task group.
                width = task_boxes[idx][k][3]    # dx
                length = task_boxes[idx][k][4]   # dy
                # width and length of the current object on the feature map.
                width = width / voxel_size[0] / self.train_cfg['out_size_factor']
                length = length / voxel_size[1] / self.train_cfg['out_size_factor']

                if width > 0 and length > 0:
                    # compute the gaussian radius.
                    radius = gaussian_radius(
                        (length, width),
                        min_overlap=self.train_cfg['gaussian_overlap'])
                    radius = max(self.train_cfg['min_radius'], int(radius))

                    # be really careful for the coordinate system of
                    # your box annotation.
                    x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][1], \
                        task_boxes[idx][k][2]   # center coordinates of the current object.

                    # compute the position of the gt_box center on the feature map.
                    coor_x = (x - pc_range[0]) / voxel_size[0] / self.train_cfg['out_size_factor']
                    coor_y = (y - pc_range[1]) / voxel_size[1] / self.train_cfg['out_size_factor']

                    center = torch.tensor([coor_x, coor_y],
                                          dtype=torch.float32,
                                          device=device)
                    center_int = center.to(torch.int32)

                    # throw out not in range objects to avoid out of array
                    # area when creating the heatmap
                    if not (0 <= center_int[0] < feature_map_size[0]
                            and 0 <= center_int[1] < feature_map_size[1]):
                        continue

                    # draw the heatmap according to the object's center position
                    # on the feature map and the gaussian radius.
                    draw_gaussian(heatmap[cls_id], center_int, radius)

                    new_idx = k
                    x, y = center_int[0], center_int[1]
                    assert (y * feature_map_size[0] + x <
                            feature_map_size[0] * feature_map_size[1])

                    # record the position of the positive sample on the feature map.
                    ind[new_idx] = y * feature_map_size[0] + x
                    mask[new_idx] = 1

                    # TODO: support other outdoor dataset
                    rot = task_boxes[idx][k][6]
                    box_dim = task_boxes[idx][k][3:6]
                    if self.norm_bbox:
                        box_dim = box_dim.log()
                    if self.with_velocity:
                        vx, vy = task_boxes[idx][k][7:]
                        anno_box[new_idx] = torch.cat([
                            center - torch.tensor([x, y], device=device),   # tx, ty
                            z.unsqueeze(0),
                            box_dim,                        # z, log(dx), log(dy), log(dz)
                            torch.sin(rot).unsqueeze(0),    # sin(rot)
                            torch.cos(rot).unsqueeze(0),    # cos(rot)
                            vx.unsqueeze(0),                # vx
                            vy.unsqueeze(0)                 # vy
                        ])
                        # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
                    else:
                        anno_box[new_idx] = torch.cat([
                            center - torch.tensor([x, y], device=device),
                            z.unsqueeze(0),
                            box_dim,
                            torch.sin(rot).unsqueeze(0),
                            torch.cos(rot).unsqueeze(0)
                        ])

            heatmaps.append(heatmap)       # append (N_cls, H, W)
            anno_boxes.append(anno_box)    # append (max_objs, 10)
            masks.append(mask)             # append (max_objs, )
            inds.append(ind)               # append (max_objs, )
        return heatmaps, anno_boxes, inds, masks
    def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
        """Loss function for CenterHead.
        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.     # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.     # List[(N_gt0, ), (N_gt1, ), ...]
            preds_dicts (dict): Tuple(
                List[ret_dict_task0_level0, ...],   len = num_levels = 1
                List[ret_dict_task1_level0, ...],
                ...
            ), len = number of SeparateHeads, each predicting its assigned classes.
            ret_dict: {
                reg: (B, 2, H, W)
                height: (B, 1, H, W)
                dim: (B, 3, H, W)
                rot: (B, 2, H, W)
                vel: (B, 2, H, W)
                heatmap: (B, n_cls, H, W)
            }
        Returns:
            dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
        # heatmaps:   List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...]  len = num of SeparateHead
        # anno_boxes: List[(B, max_objs, 10), (B, max_objs, 10), ...]  len = num of SeparateHead
        # inds:       List[(B, max_objs), (B, max_objs), ...]          len = num of SeparateHead
        # masks:      List[(B, max_objs), (B, max_objs), ...]          len = num of SeparateHead

        loss_dict = dict()
        if not self.task_specific:
            loss_dict['loss'] = 0

        for task_id, preds_dict in enumerate(preds_dicts):
            # task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels; for CenterPoint, len = 1
            # dict: {
            #   reg: (B, 2, H, W)
            #   height: (B, 1, H, W)
            #   dim: (B, 3, H, W)
            #   rot: (B, 2, H, W)
            #   vel: (B, 2, H, W)
            #   heatmap: (B, n_cls, H, W)
            # }
            # heatmap focal loss
            preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
            num_pos = heatmaps[task_id].eq(1).float().sum().item()
            cls_avg_factor = torch.clamp(
                reduce_mean(heatmaps[task_id].new_tensor(num_pos)), min=1).item()
            loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],   # (B, cur_N_cls, H, W)
                heatmaps[task_id],          # (B, cur_N_cls, H, W)
                avg_factor=cls_avg_factor
            )

            # (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
            target_box = anno_boxes[task_id]
            # reconstruct the anno_box from multiple reg heads
            preds_dict[0]['anno_box'] = torch.cat(
                (
                    preds_dict[0]['reg'],
                    preds_dict[0]['height'],
                    preds_dict[0]['dim'],
                    preds_dict[0]['rot'],
                    preds_dict[0]['vel'],
                ),
                dim=1,
            )   # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)

            # Regression loss for dimension, offset, height, rotation
            num = masks[task_id].float().sum()      # number of positive samples
            ind = inds[task_id]     # (B, max_objs)
            pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()   # (B, H, W, 10)
            pred = pred.view(pred.size(0), -1, pred.size(3))    # (B, H*W, 10)
            pred = self._gather_feat(pred, ind)     # (B, max_objs, 10)

            # (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
            mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
            num = torch.clamp(reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
            isnotnan = (~torch.isnan(target_box)).float()
            mask *= isnotnan    # supervise only the reg predictions selected by the mask.
            code_weights = self.train_cfg['code_weights']
            bbox_weights = mask * mask.new_tensor(code_weights)
            # on top of the mask, weight the individual box attributes.  (B, max_objs, 10)

            if self.task_specific:
                name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
                clip_index = [0, 2, 3, 6, 8, 10]
                for reg_task_id in range(len(name_list)):
                    pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]                  # (B, max_objs, K)
                    target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]      # (B, max_objs, K)
                    bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]]  # (B, max_objs, K)
                    loss_bbox_tmp = self.loss_bbox(
                        pred_tmp, target_box_tmp, bbox_weights_tmp,
                        avg_factor=(num + 1e-4))
                    loss_dict[f'task{task_id}.loss_%s' % (name_list[reg_task_id])] = \
                        loss_bbox_tmp * self.task_specific_weight[reg_task_id]
                loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
            else:
                loss_bbox = self.loss_bbox(pred, target_box, bbox_weights, avg_factor=num)
                loss_dict['loss'] += loss_bbox
                loss_dict['loss'] += loss_heatmap
        return loss_dict
    def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
        """Generate bboxes from bbox head predictions.
        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.
                Tuple(
                    List[ret_dict_task0_level0, ...],   len = num_levels = 1
                    List[ret_dict_task1_level0, ...],
                    ...
                ), len = number of SeparateHeads, each predicting its assigned classes.
                ret_dict: {
                    reg: (B, 2, H, W)
                    height: (B, 1, H, W)
                    dim: (B, 3, H, W)
                    rot: (B, 2, H, W)
                    vel: (B, 2, H, W)
                    heatmap: (B, n_cls, H, W)
                }
            img_metas (list[dict]): Point cloud and image's meta info.
        Returns:
            list[dict]: Decoded bbox, scores and labels after nms.
            ret_list: List[p_list0, p_list1, ...]
            p_list: List[(N, 9), (N, ), (N, )]
        """
        rets = []
        for task_id, preds_dict in enumerate(preds_dicts):
            # task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels; for CenterPoint, len = 1
            # dict: {
            #   reg: (B, 2, H, W)
            #   height: (B, 1, H, W)
            #   dim: (B, 3, H, W)
            #   rot: (B, 2, H, W)
            #   vel: (B, 2, H, W)
            #   heatmap: (B, n_cls, H, W)
            # }
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_heatmap = preds_dict[0]['heatmap'].sigmoid()      # (B, n_cls, H, W)
            batch_reg = preds_dict[0]['reg']        # (B, 2, H, W)
            batch_hei = preds_dict[0]['height']     # (B, 1, H, W)

            if self.norm_bbox:
                batch_dim = torch.exp(preds_dict[0]['dim'])     # (B, 3, H, W)
            else:
                batch_dim = preds_dict[0]['dim']

            batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)    # (B, 1, H, W)
            batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)    # (B, 1, H, W)

            if 'vel' in preds_dict[0]:
                batch_vel = preds_dict[0]['vel']    # (B, 2, H, W)
            else:
                batch_vel = None

            temp = self.bbox_coder.decode(
                batch_heatmap, batch_rots, batch_rotc, batch_hei,
                batch_dim, batch_vel, reg=batch_reg, task_id=task_id)
            # temp: List[p_dict0, p_dict1, ...] len=bs
            # p_dict = {
            #   'bboxes': boxes3d,  # (K', 9)
            #   'scores': scores,   # (K', )
            #   'labels': labels    # (K', )
            # }
            batch_reg_preds = [box['bboxes'] for box in temp]   # List[(K0, 9), (K1, 9), ...] len = bs
            batch_cls_preds = [box['scores'] for box in temp]   # List[(K0, ), (K1, ), ...] len = bs
            batch_cls_labels = [box['labels'] for box in temp]  # List[(K0, ), (K1, ), ...] len = bs

            nms_type = self.test_cfg.get('nms_type')
            if isinstance(nms_type, list):
                nms_type = nms_type[task_id]
            if nms_type == 'circle':
                ret_task = []
                for i in range(batch_size):
                    boxes3d = temp[i]['bboxes']
                    scores = temp[i]['scores']
                    labels = temp[i]['labels']
                    centers = boxes3d[:, [0, 1]]
                    boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
                    keep = torch.tensor(
                        circle_nms(
                            boxes.detach().cpu().numpy(),
                            self.test_cfg['min_radius'][task_id],
                            post_max_size=self.test_cfg['post_max_size']),
                        dtype=torch.long,
                        device=boxes.device)
                    boxes3d = boxes3d[keep]
                    scores = scores[keep]
                    labels = labels[keep]
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                    ret_task.append(ret)
                rets.append(ret_task)
            else:
                rets.append(
                    self.get_task_detections(batch_cls_preds, batch_reg_preds,
                                             batch_cls_labels, img_metas,
                                             task_id))

        # rets: List[ret_task0, ret_task1, ...], len = num_tasks
        # ret_task: List[p_dict0, p_dict1, ...], len = batch_size
        # p_dict: dict{
        #   bboxes: (K', 9)
        #   scores: (K', )
        #   labels: (K', )
        # }
        # Merge branches results
        num_samples = len(rets[0])      # bs
        ret_list = []
        # iterate over the batch, then merge the predictions of all tasks.
        for i in range(num_samples):
            for k in rets[0][i].keys():
                if k == 'bboxes':
                    bboxes = torch.cat([ret[i][k] for ret in rets])
                    # bboxes can simply be concatenated.
                    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
                    bboxes = img_metas[i]['box_type_3d'](bboxes, self.bbox_coder.code_size)
                elif k == 'scores':
                    scores = torch.cat([ret[i][k] for ret in rets])
                    # scores can simply be concatenated.
                elif k == 'labels':
                    flag = 0
                    for j, num_class in enumerate(self.num_classes):
                        # labels must be shifted, since the predicted label is
                        # relative to the task group.
                        rets[j][i][k] += flag
                        flag += num_class
                    labels = torch.cat([ret[i][k].int() for ret in rets])
            ret_list.append([bboxes, scores, labels])
        # ret_list: List[p_list0, p_list1, ...]
        # p_list: List[(N, 9), (N, ), (N, )]
        return ret_list
    def _nms(self, heat, kernel=3):
        pad = (kernel - 1) // 2
        hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
        keep = (hmax == heat).float()
        return heat * keep
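
    # --- Editor's sketch (not part of the original commit) -------------------
    # The max-pool "NMS" above keeps only local maxima of the heatmap: a cell
    # survives iff it equals the max of its kernel-sized neighbourhood. Toy
    # values below.
    @staticmethod
    def _demo_maxpool_nms():
        import torch
        import torch.nn.functional as F
        heat = torch.tensor([[[[0.1, 0.9, 0.2],
                               [0.3, 0.5, 0.1],
                               [0.8, 0.2, 0.7]]]])    # (1, 1, 3, 3)
        hmax = F.max_pool2d(heat, (3, 3), stride=1, padding=1)
        keep = (hmax == heat).float()
        return heat * keep    # only the peaks 0.9, 0.8 and 0.7 survive
    # -------------------------------------------------------------------------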
    def get_centers(self, preds_dicts, img_metas, img=None, rescale=False):
        rets = []
        for task_id, preds_dict in enumerate(preds_dicts):
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_heatmap = preds_dict[0]['heatmap'].sigmoid()      # (B, n_cls, H, W)
            batch_reg = preds_dict[0]['reg']        # (B, 2, H, W)
            batch_hei = preds_dict[0]['height']     # (B, 1, H, W)
            batch_heatmap = self._nms(batch_heatmap)
            temp = self.bbox_coder.center_decode(
                batch_heatmap, batch_hei, reg=batch_reg, task_id=task_id)
            batch_reg_preds = [box['centers'] for box in temp]  # List[(K0, 9), (K1, 9), ...] len = bs
            batch_cls_preds = [box['scores'] for box in temp]   # List[(K0, ), (K1, ), ...] len = bs
            batch_cls_labels = [box['labels'] for box in temp]  # List[(K0, ), (K1, ), ...] len = bs
            ret_list = [batch_reg_preds, batch_cls_preds, batch_cls_labels]
        return ret_list
    def get_task_detections(self, batch_cls_preds, batch_reg_preds,
                            batch_cls_labels, img_metas, task_id):
        """Rotate nms for each task.
        Args:
            batch_cls_preds (list[torch.Tensor]): Prediction score with the
                shape of [N].       # List[(K0, ), (K1, ), ...] len = bs
            batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
                shape of [N, 9].    # List[(K0, 9), (K1, 9), ...] len = bs
            batch_cls_labels (list[torch.Tensor]): Prediction label with the
                shape of [N].       # List[(K0, ), (K1, ), ...] len = bs
            img_metas (list[dict]): Meta information of each sample.
        Returns:
            list[dict[str: torch.Tensor]]: contains the following keys:
                -bboxes (torch.Tensor): Prediction bboxes after nms with the
                    shape of [N, 9].
                -scores (torch.Tensor): Prediction scores after nms with the
                    shape of [N].
                -labels (torch.Tensor): Prediction labels after nms with the
                    shape of [N].
            List[p_dict0, p_dict1, ...] len = batch_size
            p_dict: dict{
                bboxes: (K', 9)
                scores: (K', )
                labels: (K', )
            }
        """
        predictions_dicts = []
        # iterate over the topK predictions of each sample in the batch.
        for i, (box_preds, cls_preds, cls_labels) in enumerate(
                zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
            # box_preds: (K, 9)
            # cls_preds: (K, )
            # cls_labels: (K, )
            default_val = [1.0 for _ in range(len(self.task_heads))]
            factor = self.test_cfg.get('nms_rescale_factor', default_val)[task_id]
            if isinstance(factor, list):
                # List[float, float, ..] len = number of classes the current task predicts.
                # rescale box_preds by the corresponding factor; typically small
                # objects are enlarged and large objects are shrunk.
                for cid in range(len(factor)):
                    box_preds[cls_labels == cid, 3:6] = \
                        box_preds[cls_labels == cid, 3:6] * factor[cid]
            else:
                box_preds[:, 3:6] = box_preds[:, 3:6] * factor

            # Apply NMS in birdeye view
            top_labels = cls_labels.long()      # (K, )
            top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
                else cls_preds      # (K, )

            if top_scores.shape[0] != 0:
                boxes_for_nms = img_metas[i]['box_type_3d'](
                    box_preds[:, :], self.bbox_coder.code_size).bev
                # (K, 5) (x, y, dx, dy, yaw)
                # the nms in 3d detection just remove overlap boxes.
                if isinstance(self.test_cfg['nms_thr'], list):
                    nms_thresh = self.test_cfg['nms_thr'][task_id]
                else:
                    nms_thresh = self.test_cfg['nms_thr']
                selected = nms_bev(
                    boxes_for_nms,
                    top_scores,
                    thresh=nms_thresh,
                    pre_max_size=self.test_cfg['pre_max_size'],
                    post_max_size=self.test_cfg['post_max_size'],
                    xyxyr2xywhr=False,
                )
            else:
                selected = []

            # after NMS, rescale the boxes back to their original size.
            if isinstance(factor, list):
                for cid in range(len(factor)):
                    box_preds[top_labels == cid, 3:6] = \
                        box_preds[top_labels == cid, 3:6] / factor[cid]
            else:
                box_preds[:, 3:6] = box_preds[:, 3:6] / factor

            # if selected is not None:
            selected_boxes = box_preds[selected]        # (K', 9)
            selected_labels = top_labels[selected]      # (K', )
            selected_scores = top_scores[selected]      # (K', )

            # finally generate predictions.
            if selected_boxes.shape[0] != 0:
                predictions_dict = dict(
                    bboxes=selected_boxes,
                    scores=selected_scores,
                    labels=selected_labels)
            else:
                dtype = batch_reg_preds[0].dtype
                device = batch_reg_preds[0].device
                predictions_dict = dict(
                    bboxes=torch.zeros([0, self.bbox_coder.code_size],
                                       dtype=dtype, device=device),
                    scores=torch.zeros([0], dtype=dtype, device=device),
                    labels=torch.zeros([0], dtype=top_labels.dtype, device=device))
            predictions_dicts.append(predictions_dict)
        return predictions_dicts
projects/mmdet3d_plugin/models/dense_heads/bev_occ_head.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
import numpy as np
from mmdet3d.models.builder import HEADS, build_loss
from ..losses.semkitti_loss import sem_scal_loss, geo_scal_loss
from ..losses.lovasz_softmax import lovasz_softmax

nusc_class_frequencies = np.array([
    944004, 1897170, 152386, 2391677, 16957802, 724139, 189027, 2074468,
    413451, 2384460, 5916653, 175883646, 4275424, 51393615, 61411620,
    105975596, 116424404, 1892500630
])


@HEADS.register_module()
class BEVOCCHead3D(BaseModule):
    def __init__(self,
                 in_dim=32,
                 out_dim=32,
                 use_mask=True,
                 num_classes=18,
                 use_predicter=True,
                 class_balance=False,
                 loss_occ=None):
        super(BEVOCCHead3D, self).__init__()
        self.out_dim = 32
        out_channels = out_dim if use_predicter else num_classes
        self.final_conv = ConvModule(
            in_dim,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            conv_cfg=dict(type='Conv3d')
        )
        self.use_predicter = use_predicter
        if use_predicter:
            self.predicter = nn.Sequential(
                nn.Linear(self.out_dim, self.out_dim * 2),
                nn.Softplus(),
                nn.Linear(self.out_dim * 2, num_classes),
            )

        self.num_classes = num_classes
        self.use_mask = use_mask
        self.class_balance = class_balance
        if self.class_balance:
            class_weights = torch.from_numpy(
                1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
            self.cls_weights = class_weights
            loss_occ['class_weight'] = class_weights
        self.loss_occ = build_loss(loss_occ)

    def forward(self, img_feats):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx)
        Returns:
        """
        # (B, C, Dz, Dy, Dx) --> (B, C, Dz, Dy, Dx) --> (B, Dx, Dy, Dz, C)
        occ_pred = self.final_conv(img_feats).permute(0, 4, 3, 2, 1)
        if self.use_predicter:
            # (B, Dx, Dy, Dz, C) --> (B, Dx, Dy, Dz, 2*C) --> (B, Dx, Dy, Dz, n_cls)
            occ_pred = self.predicter(occ_pred)
        return occ_pred

    def loss(self, occ_pred, voxel_semantics, mask_camera):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, n_cls)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        loss = dict()
        voxel_semantics = voxel_semantics.long()
        if self.use_mask:
            mask_camera = mask_camera.to(torch.int32)   # (B, Dx, Dy, Dz)
            # (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
            voxel_semantics = voxel_semantics.reshape(-1)
            # (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
            preds = occ_pred.reshape(-1, self.num_classes)
            # (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
            mask_camera = mask_camera.reshape(-1)
            if self.class_balance:
                valid_voxels = voxel_semantics[mask_camera.bool()]
                num_total_samples = 0
                for i in range(self.num_classes):
                    num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
            else:
                num_total_samples = mask_camera.sum()
            loss_occ = self.loss_occ(
                preds,              # (B*Dx*Dy*Dz, n_cls)
                voxel_semantics,    # (B*Dx*Dy*Dz, )
                mask_camera,        # (B*Dx*Dy*Dz, )
                avg_factor=num_total_samples
            )
        else:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = occ_pred.reshape(-1, self.num_classes)
            if self.class_balance:
                num_total_samples = 0
                for i in range(self.num_classes):
                    num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
            else:
                num_total_samples = len(voxel_semantics)
            loss_occ = self.loss_occ(preds, voxel_semantics, avg_factor=num_total_samples)
        loss['loss_occ'] = loss_occ
        return loss

    def get_occ(self, occ_pred, img_metas=None):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, C)
            img_metas:
        Returns:
            List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        occ_score = occ_pred.softmax(-1)    # (B, Dx, Dy, Dz, C)
        occ_res = occ_score.argmax(-1)      # (B, Dx, Dy, Dz)
        occ_res = occ_res.cpu().numpy().astype(np.uint8)    # (B, Dx, Dy, Dz)
        return list(occ_res)
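
# --- Editor's sketch (not part of the original commit) -----------------------
# The class-balance weights above: inverse log-frequency, so rare classes get
# larger weights. Shown on the first three entries of nusc_class_frequencies.
def _demo_class_weights():
    import numpy as np
    freqs = np.array([944004, 1897170, 152386])
    return 1 / np.log(freqs + 0.001)    # the rarest class (152386) gets the largest weight
# ------------------------------------------------------------------------------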
@HEADS.register_module()
class BEVOCCHead2D(BaseModule):
    def __init__(self,
                 in_dim=256,
                 out_dim=256,
                 Dz=16,
                 use_mask=True,
                 num_classes=18,
                 use_predicter=True,
                 class_balance=False,
                 loss_occ=None,
                 ):
        super(BEVOCCHead2D, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.Dz = Dz
        out_channels = out_dim if use_predicter else num_classes * Dz
        self.final_conv = ConvModule(
            self.in_dim,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            conv_cfg=dict(type='Conv2d')
        )
        self.use_predicter = use_predicter
        if use_predicter:
            self.predicter = nn.Sequential(
                nn.Linear(self.out_dim, self.out_dim * 2),
                nn.Softplus(),
                nn.Linear(self.out_dim * 2, num_classes * Dz),
            )

        self.use_mask = use_mask
        self.num_classes = num_classes
        self.class_balance = class_balance
        if self.class_balance:
            class_weights = torch.from_numpy(
                1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
            self.cls_weights = class_weights
            loss_occ['class_weight'] = class_weights

        # ce loss
        self.loss_occ = build_loss(loss_occ)

    # @torch.compiler.disable
    def forward(self, img_feats):
        """
        Args:
            img_feats: (B, C, Dy, Dx)
        Returns:
        """
        # (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
        occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
        bs, Dx, Dy = occ_pred.shape[:3]
        if self.use_predicter:
            # (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
            occ_pred = self.predicter(occ_pred)
            occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
        return occ_pred

    def loss(self, occ_pred, voxel_semantics, mask_camera):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, n_cls)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        loss = dict()
        voxel_semantics = voxel_semantics.long()
        if self.use_mask:
            mask_camera = mask_camera.to(torch.int32)   # (B, Dx, Dy, Dz)
            # (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
            voxel_semantics = voxel_semantics.reshape(-1)
            # (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
            preds = occ_pred.reshape(-1, self.num_classes)
            # (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
            mask_camera = mask_camera.reshape(-1)
            if self.class_balance:
                valid_voxels = voxel_semantics[mask_camera.bool()]
                num_total_samples = 0
                for i in range(self.num_classes):
                    num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
            else:
                num_total_samples = mask_camera.sum()
            loss_occ = self.loss_occ(
                preds,              # (B*Dx*Dy*Dz, n_cls)
                voxel_semantics,    # (B*Dx*Dy*Dz, )
                mask_camera,        # (B*Dx*Dy*Dz, )
                avg_factor=num_total_samples
            )
            loss['loss_occ'] = loss_occ
        else:
            voxel_semantics = voxel_semantics.reshape(-1)
            preds = occ_pred.reshape(-1, self.num_classes)
            if self.class_balance:
                num_total_samples = 0
                for i in range(self.num_classes):
                    num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
            else:
                num_total_samples = len(voxel_semantics)
            loss_occ = self.loss_occ(preds, voxel_semantics, avg_factor=num_total_samples)
            loss['loss_occ'] = loss_occ
        return loss

    def get_occ(self, occ_pred, img_metas=None):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, C)
            img_metas:
        Returns:
            List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        occ_score = occ_pred.softmax(-1)    # (B, Dx, Dy, Dz, C)
        occ_res = occ_score.argmax(-1)      # (B, Dx, Dy, Dz)
        occ_res = occ_res.cpu().numpy().astype(np.uint8)    # (B, Dx, Dy, Dz)
        return list(occ_res)
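
# --- Editor's sketch (not part of the original commit) -----------------------
# The channel-to-height idea behind BEVOCCHead2D: Dz*n_cls occupancy logits are
# packed into BEV channels, then unpacked into a 3D voxel grid. The head itself
# calls .view after the Linear predictor (where the tensor is contiguous); this
# toy version uses .reshape because the permuted tensor here is not.
def _demo_channel_to_height():
    import torch
    B, Dx, Dy, Dz, n_cls = 1, 200, 200, 16, 18
    bev = torch.randn(B, n_cls * Dz, Dy, Dx)    # (B, Dz*n_cls, Dy, Dx)
    occ = bev.permute(0, 3, 2, 1).reshape(B, Dx, Dy, Dz, n_cls)
    return occ.shape    # torch.Size([1, 200, 200, 16, 18])
# ------------------------------------------------------------------------------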
@HEADS.register_module()
class BEVOCCHead2D_V2(BaseModule):
    # Use stronger loss setting
    def __init__(self,
                 in_dim=256,
                 out_dim=256,
                 Dz=16,
                 use_mask=True,
                 num_classes=18,
                 use_predicter=True,
                 class_balance=False,
                 loss_occ=None,
                 ):
        super(BEVOCCHead2D_V2, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.Dz = Dz
        # voxel-level prediction
        self.occ_convs = nn.ModuleList()
        self.final_conv = ConvModule(
            in_dim,
            self.out_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            conv_cfg=dict(type='Conv2d')
        )
        self.use_predicter = use_predicter
        if use_predicter:
            self.predicter = nn.Sequential(
                nn.Linear(self.out_dim, self.out_dim * 2),
                nn.Softplus(),
                nn.Linear(self.out_dim * 2, num_classes * Dz),
            )

        self.use_mask = use_mask
        self.num_classes = num_classes
        self.class_balance = class_balance
        if self.class_balance:
            class_weights = torch.from_numpy(
                1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
            self.cls_weights = class_weights
        self.loss_occ = build_loss(loss_occ)

    def forward(self, img_feats):
        """
        Args:
            img_feats: (B, C, Dy=200, Dx=200)
            img_feats: [(B, C, 100, 100), (B, C, 50, 50), (B, C, 25, 25)] if ms
        Returns:
        """
        # (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
        occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
        bs, Dx, Dy = occ_pred.shape[:3]
        if self.use_predicter:
            # (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
            occ_pred = self.predicter(occ_pred)
            occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
        return occ_pred

    def loss(self, occ_pred, voxel_semantics, mask_camera):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, n_cls)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        loss = dict()
        voxel_semantics = voxel_semantics.long()    # (B, Dx, Dy, Dz)
        preds = occ_pred.permute(0, 4, 1, 2, 3).contiguous()    # (B, n_cls, Dx, Dy, Dz)
        loss_occ = self.loss_occ(
            preds,
            voxel_semantics,
            weight=self.cls_weights.to(preds),
        ) * 100.0
        loss['loss_occ'] = loss_occ
        loss['loss_voxel_sem_scal'] = sem_scal_loss(preds, voxel_semantics)
        loss['loss_voxel_geo_scal'] = geo_scal_loss(preds, voxel_semantics, non_empty_idx=17)
        loss['loss_voxel_lovasz'] = lovasz_softmax(torch.softmax(preds, dim=1), voxel_semantics)
        return loss

    def get_occ(self, occ_pred, img_metas=None):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, C)
            img_metas:
        Returns:
            List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        occ_score = occ_pred.softmax(-1)    # (B, Dx, Dy, Dz, C)
        occ_res = occ_score.argmax(-1)      # (B, Dx, Dy, Dz)
        occ_res = occ_res.cpu().numpy().astype(np.uint8)    # (B, Dx, Dy, Dz)
        return list(occ_res)

    def get_occ_gpu(self, occ_pred, img_metas=None):
        """
        Args:
            occ_pred: (B, Dx, Dy, Dz, C)
            img_metas:
        Returns:
            List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        occ_score = occ_pred.softmax(-1)        # (B, Dx, Dy, Dz, C)
        occ_res = occ_score.argmax(-1).int()    # (B, Dx, Dy, Dz)
        return list(occ_res)
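
# --- Editor's sketch (not part of the original commit) -----------------------
# Decoding occupancy predictions to labels, as in the get_occ methods above,
# on a toy grid. Note softmax is monotonic, so argmax over the raw logits
# would yield the same labels.
def _demo_get_occ():
    import torch
    occ_pred = torch.randn(1, 4, 4, 2, 18)       # toy (B, Dx, Dy, Dz, n_cls) logits
    occ_res = occ_pred.softmax(-1).argmax(-1)    # (B, Dx, Dy, Dz) class indices
    return occ_res.shape
# ------------------------------------------------------------------------------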
projects/mmdet3d_plugin/models/detectors/__init__.py
0 → 100644
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdet4d import BEVDet4D
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D
from .bevdet_occ import BEVDetOCC, BEVDepthOCC, BEVDepth4DOCC, BEVStereo4DOCC, \
    BEVDepth4DPano, BEVDepthPano, BEVDepthPanoTRT

__all__ = ['BEVDet', 'BEVDepth', 'BEVDet4D', 'BEVDepth4D', 'BEVStereo4D',
           'BEVDetOCC', 'BEVDepthOCC', 'BEVDepth4DOCC', 'BEVStereo4DOCC',
           'BEVDepthPano', 'BEVDepth4DPano', 'BEVDepthPanoTRT']
\ No newline at end of file
projects/mmdet3d_plugin/models/detectors/bevdepth.py
0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from .bevdet import BEVDet
from mmdet3d.models import builder


@DETECTORS.register_module()
class BEVDepth(BEVDet):
    def __init__(self,
                 img_backbone,
                 img_neck,
                 img_view_transformer,
                 img_bev_encoder_backbone,
                 img_bev_encoder_neck,
                 pts_bbox_head=None,
                 **kwargs):
        super(BEVDepth, self).__init__(
            img_backbone=img_backbone,
            img_neck=img_neck,
            img_view_transformer=img_view_transformer,
            img_bev_encoder_backbone=img_bev_encoder_backbone,
            img_bev_encoder_neck=img_bev_encoder_neck,
            pts_bbox_head=pts_bbox_head)

    def image_encoder(self, img, stereo=False):
        """
        Args:
            img: (B, N, 3, H, W)
            stereo: bool
        Returns:
            x: (B, N, C, fH, fW)
            stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
        """
        imgs = img
        B, N, C, imH, imW = imgs.shape
        imgs = imgs.view(B * N, C, imH, imW)
        x = self.img_backbone(imgs)
        stereo_feat = None
        if stereo:
            stereo_feat = x[0]
            x = x[1:]
        if self.with_img_neck:
            x = self.img_neck(x)
            if type(x) in [list, tuple]:
                x = x[0]
        _, output_dim, output_H, output_W = x.shape
        x = x.view(B, N, output_dim, output_H, output_W)
        return x, stereo_feat

    @force_fp32()
    def bev_encoder(self, x):
        """
        Args:
            x: (B, C, Dy, Dx)
        Returns:
            x: (B, C', 2*Dy, 2*Dx)
        """
        x = self.img_bev_encoder_backbone(x)
        x = self.img_bev_encoder_neck(x)
        if type(x) in [list, tuple]:
            x = x[0]
        return x

    def prepare_inputs(self, inputs):
        # split the inputs into each frame
        assert len(inputs) == 7
        B, N, C, H, W = inputs[0].shape
        imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
            inputs

        sensor2egos = sensor2egos.view(B, N, 4, 4)
        ego2globals = ego2globals.view(B, N, 4, 4)

        # calculate the transformation from adj sensor to key ego
        keyego2global = ego2globals[:, 0, ...].unsqueeze(1)     # (B, 1, 4, 4)
        global2keyego = torch.inverse(keyego2global.double())  # (B, 1, 4, 4)
        sensor2keyegos = \
            global2keyego @ ego2globals.double() @ sensor2egos.double()    # (B, N_views, 4, 4)
        sensor2keyegos = sensor2keyegos.float()

        return [imgs, sensor2keyegos, ego2globals, intrins,
                post_rots, post_trans, bda]

    def extract_img_feat(self, img_inputs, img_metas, **kwargs):
        """Extract features of images.
        img_inputs:
            imgs: (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins: (B, N_views, 3, 3)
            post_rots: (B, N_views, 3, 3)
            post_trans: (B, N_views, 3)
            bda_rot: (B, 3, 3)
        Returns:
            x: [(B, C', H', W'), ]
            depth: (B*N, D, fH, fW)
        """
        imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda = \
            self.prepare_inputs(img_inputs)
        x, _ = self.image_encoder(imgs)     # x: (B, N, C, fH, fW)
        mlp_input = self.img_view_transformer.get_mlp_input(
            sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda)   # (B, N_views, 27)
        x, depth = self.img_view_transformer(
            [x, sensor2keyegos, ego2globals, intrins, post_rots, post_trans,
             bda, mlp_input])
        # x: (B, C, Dy, Dx)
        # depth: (B*N, D, fH, fW)
        x = self.bev_encoder(x)
        return [x], depth

    def extract_feat(self, points, img_inputs, img_metas, **kwargs):
        """Extract features from images and points.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_inputs:
                imgs: (B, N_views, 3, H, W)
                sensor2egos: (B, N_views, 4, 4)
                ego2globals: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                bda_rot: (B, 3, 3)
        """
        img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
        pts_feats = None
        return img_feats, pts_feats, depth

    def forward_train(self,
                      points=None,
                      img_inputs=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      img_metas=None,
                      gt_bboxes=None,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_inputs:
                imgs: (B, N_views, 3, H, W)     # N_views = 6 * (N_history + 1)
                sensor2egos: (B, N_views, 4, 4)
                ego2globals: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                bda_rot: (B, 3, 3)
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses = dict(loss_depth=loss_depth)
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)
        return losses

    def forward_test(self, points=None, img_inputs=None, img_metas=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
        """
        for var, name in [(img_inputs, 'img_inputs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(img_inputs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img_inputs), len(img_metas)))

        if not isinstance(img_inputs[0][0], list):
            img_inputs = [img_inputs] if img_inputs is None else img_inputs
            points = [points] if points is None else points
            return self.simple_test(points[0], img_metas[0], img_inputs[0],
                                    **kwargs)
        else:
            return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)

    def aug_test(self, points, img_metas, img=None, rescale=False):
        """Test function without augmentation."""
        assert False

    def simple_test(self, points, img_metas, img_inputs=None, rescale=False, **kwargs):
        """Test function without augmentation.
        Returns:
            bbox_list: List[dict0, dict1, ...] len = bs
            dict: {
                'pts_bbox': dict: {
                    'boxes_3d': (N, 9)
                    'scores_3d': (N, )
                    'labels_3d': (N, )
                }
            }
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        bbox_list = [dict() for _ in range(len(img_metas))]
        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
        # bbox_pts: List[dict0, dict1, ...], len = batch_size
        # dict: {
        #   'boxes_3d': (N, 9)
        #   'scores_3d': (N, )
        #   'labels_3d': (N, )
        # }
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
        return bbox_list

    def forward_dummy(self, points=None, img_metas=None, img_inputs=None, **kwargs):
        img_feats, _, _ = self.extract_feat(
            points, img=img_inputs, img_metas=img_metas, **kwargs)
        assert self.with_pts_bbox
        outs = self.pts_bbox_head(img_feats)
        return outs
\ No newline at end of file
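
The pose composition used in prepare_inputs above has a built-in sanity check: for the key view itself, inv(keyego2global) @ ego2global @ sensor2ego must collapse back to that view's sensor2ego. A minimal sketch with placeholder poses (random translations only, not real calibration data):

import torch

B, N = 2, 6

def random_pose(*shape):
    # Identity rotation plus a random translation; enough to check
    # shapes and composition order, not a real rigid transform set.
    T = torch.eye(4).repeat(*shape, 1, 1)
    T[..., :3, 3] = torch.randn(*shape, 3)
    return T

sensor2egos = random_pose(B, N)     # (B, N, 4, 4)
ego2globals = random_pose(B, N)     # (B, N, 4, 4)

keyego2global = ego2globals[:, 0, ...].unsqueeze(1)     # (B, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double())   # (B, 1, 4, 4)
sensor2keyegos = (global2keyego @ ego2globals.double()
                  @ sensor2egos.double()).float()       # (B, N, 4, 4)
# For view 0 the chain reduces to sensor2ego of the key view.
print(torch.allclose(sensor2keyegos[:, 0], sensor2egos[:, 0], atol=1e-5))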
projects/mmdet3d_plugin/models/detectors/bevdepth4d.py
0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet4d import BEVDet4D


@DETECTORS.register_module()
class BEVDepth4D(BEVDet4D):

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses = dict(loss_depth=loss_depth)
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)
        return losses
\ No newline at end of file
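
The depth supervision above delegates to img_view_transformer.get_depth_loss, whose implementation is not part of this file. As a hedged sketch only: one common BEVDepth-style formulation is binary cross-entropy between the predicted per-pixel depth distribution and the one-hot discretized LiDAR depth. The bin layout (1-45 m in 1 m steps, matching the 'depth' grid config that appears later in this commit) and the helper name toy_depth_loss are assumptions for illustration:

import torch
import torch.nn.functional as F

def toy_depth_loss(depth_pred, gt_depth, d_min=1.0, d_max=45.0, d_step=1.0):
    """depth_pred: (M, D) per-pixel depth distribution (rows sum to 1).
    gt_depth: (M,) metric depth in meters; values outside [d_min, d_max) are
    treated as missing LiDAR returns and ignored."""
    D = depth_pred.shape[1]
    valid = (gt_depth >= d_min) & (gt_depth < d_max)
    bins = ((gt_depth[valid] - d_min) / d_step).long().clamp(max=D - 1)
    onehot = F.one_hot(bins, num_classes=D).float()
    pred = depth_pred[valid].clamp(min=1e-7)
    return F.binary_cross_entropy(pred, onehot)

pred = torch.softmax(torch.randn(10, 44), dim=1)    # 44 bins for [1, 45) @ 1 m
gt = torch.rand(10) * 50
print(toy_depth_loss(pred, gt))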
projects/mmdet3d_plugin/models/detectors/bevdet.py
0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from mmdet3d.models import DETECTORS
from mmdet3d.models import CenterPoint
from mmdet3d.models import builder


@DETECTORS.register_module()
class BEVDet(CenterPoint):

    def __init__(self,
                 img_backbone,
                 img_neck,
                 img_view_transformer,
                 img_bev_encoder_backbone,
                 img_bev_encoder_neck,
                 pts_bbox_head=None,
                 **kwargs):
        super(BEVDet, self).__init__(img_backbone=img_backbone,
                                     img_neck=img_neck,
                                     pts_bbox_head=pts_bbox_head,
                                     **kwargs)
        self.img_view_transformer = builder.build_neck(img_view_transformer)
        self.img_bev_encoder_backbone = builder.build_backbone(img_bev_encoder_backbone)
        self.img_bev_encoder_neck = builder.build_neck(img_bev_encoder_neck)

    def image_encoder(self, img, stereo=False):
        """
        Args:
            img: (B, N, 3, H, W)
            stereo: bool
        Returns:
            x: (B, N, C, fH, fW)
            stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
        """
        imgs = img
        B, N, C, imH, imW = imgs.shape
        imgs = imgs.view(B * N, C, imH, imW)
        # imgs = imgs.to(memory_format=torch.channels_last)
        x = self.img_backbone(imgs)
        stereo_feat = None
        if stereo:
            stereo_feat = x[0]
            x = x[1:]
        if self.with_img_neck:
            x = self.img_neck(x)
            if type(x) in [list, tuple]:
                x = x[0]
        _, output_dim, output_H, output_W = x.shape
        x = x.view(B, N, output_dim, output_H, output_W)
        return x, stereo_feat
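
image_encoder folds the view axis into the batch so an ordinary 2D backbone can process all cameras at once, then restores the per-view layout. A self-contained shape check with a stand-in convolution instead of the real img_backbone:

import torch
import torch.nn as nn

backbone = nn.Conv2d(3, 8, 3, stride=2, padding=1)  # stand-in for img_backbone
imgs = torch.randn(2, 6, 3, 32, 48)                 # (B, N, 3, H, W)
B, N, C, H, W = imgs.shape
feat = backbone(imgs.view(B * N, C, H, W))          # (B*N, 8, 16, 24)
_, c, h, w = feat.shape
feat = feat.view(B, N, c, h, w)                     # back to per-view layout
print(feat.shape)                                   # torch.Size([2, 6, 8, 16, 24])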
    @force_fp32()
    def bev_encoder(self, x):
        """
        Args:
            x: (B, C, Dy, Dx)
        Returns:
            x: (B, C', 2*Dy, 2*Dx)
        """
        x = self.img_bev_encoder_backbone(x)
        x = self.img_bev_encoder_neck(x)
        if type(x) in [list, tuple]:
            x = x[0]
        return x

    def prepare_inputs(self, inputs):
        # split the inputs into each frame
        assert len(inputs) == 7
        B, N, C, H, W = inputs[0].shape
        imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = inputs
        sensor2egos = sensor2egos.view(B, N, 4, 4)
        ego2globals = ego2globals.view(B, N, 4, 4)

        # calculate the transformation from adj sensor to key ego
        keyego2global = ego2globals[:, 0, ...].unsqueeze(1)     # (B, 1, 4, 4)
        global2keyego = torch.inverse(keyego2global.double())   # (B, 1, 4, 4)
        sensor2keyegos = \
            global2keyego @ ego2globals.double() @ sensor2egos.double()    # (B, N_views, 4, 4)
        sensor2keyegos = sensor2keyegos.float()

        return [imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda]
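
Note that the inversion above is done in float64 and only cast back to float32 at the end. A small illustration of why (values are arbitrary; a badly scaled 4x4 loses noticeably more precision when inverted in float32):

import torch

T = torch.eye(4)
T[:3, :3] *= 1e-4                   # badly scaled rotation block
T[:3, 3] = torch.tensor([1000.0, -500.0, 2.0])
err32 = (torch.inverse(T) @ T - torch.eye(4)).abs().max()
err64 = (torch.inverse(T.double()) @ T.double() - torch.eye(4).double()).abs().max()
print(err32.item(), err64.item())   # the float64 round-trip error is much smaller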
    def extract_img_feat(self, img_inputs, img_metas, **kwargs):
        """Extract features of images.
        img_inputs:
            imgs: (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins: (B, N_views, 3, 3)
            post_rots: (B, N_views, 3, 3)
            post_trans: (B, N_views, 3)
            bda_rot: (B, 3, 3)
        Returns:
            x: [(B, C', H', W'), ]
            depth: (B*N, D, fH, fW)
        """
        img_inputs = self.prepare_inputs(img_inputs)
        x, _ = self.image_encoder(img_inputs[0])    # x: (B, N, C, fH, fW)
        x, depth = self.img_view_transformer([x] + img_inputs[1:7])
        # x: (B, C, Dy, Dx)
        # depth: (B*N, D, fH, fW)
        x = self.bev_encoder(x)
        return [x], depth

    @torch.compile
    def extract_feat(self, points, img_inputs, img_metas, **kwargs):
        """Extract features from images and points.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_inputs:
                imgs: (B, N_views, 3, H, W)
                sensor2egos: (B, N_views, 4, 4)
                ego2globals: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                bda_rot: (B, 3, 3)
        """
        img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
        pts_feats = None
        return img_feats, pts_feats, depth

    def forward_train(self,
                      points=None,
                      img_inputs=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      img_metas=None,
                      gt_bboxes=None,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_inputs:
                imgs: (B, N_views, 3, H, W)     # N_views = 6 * (N_history + 1)
                sensor2egos: (B, N_views, 4, 4)
                ego2globals: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                bda_rot: (B, 3, 3)
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        img_feats, pts_feats, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)
        return losses

    def forward_test(self,
                     points=None,
                     img_inputs=None,
                     img_metas=None,
                     **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
        """
        for var, name in [(img_inputs, 'img_inputs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(name, type(var)))

        num_augs = len(img_inputs)
        if num_augs != len(img_metas):
            raise ValueError('num of augmentations ({}) != num of image meta ({})'.format(
                len(img_inputs), len(img_metas)))

        if not isinstance(img_inputs[0][0], list):
            img_inputs = [img_inputs] if img_inputs is None else img_inputs
            points = [points] if points is None else points
            return self.simple_test(points[0], img_metas[0], img_inputs[0], **kwargs)
        else:
            return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)

    def aug_test(self, points, img_metas, img=None, rescale=False):
        """Test function with augmentation (not supported)."""
        assert False

    def simple_test(self,
                    points,
                    img_metas,
                    img_inputs=None,
                    rescale=False,
                    **kwargs):
        """Test function without augmentation.
        Returns:
            bbox_list: List[dict0, dict1, ...]   len = bs
            dict: {
                'pts_bbox': dict: {
                    'boxes_3d': (N, 9)
                    'scores_3d': (N, )
                    'labels_3d': (N, )
                }
            }
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        bbox_list = [dict() for _ in range(len(img_metas))]
        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
        # bbox_pts: List[dict0, dict1, ...], len = batch_size
        # dict: {
        #   'boxes_3d': (N, 9)
        #   'scores_3d': (N, )
        #   'labels_3d': (N, )
        # }
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
        return bbox_list

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        assert self.with_pts_bbox
        outs = self.pts_bbox_head(img_feats)
        return outs
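
For reference, the per-sample result structure that simple_test assembles, shown with dummy tensors standing in for the head outputs:

import torch

batch_size = 2
bbox_pts = [{'boxes_3d': torch.zeros(5, 9),
             'scores_3d': torch.zeros(5),
             'labels_3d': torch.zeros(5, dtype=torch.long)}
            for _ in range(batch_size)]
bbox_list = [dict() for _ in range(batch_size)]
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
    result_dict['pts_bbox'] = pts_bbox
print(list(bbox_list[0].keys()))    # ['pts_bbox']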
projects/mmdet3d_plugin/models/detectors/bevdet4d.py
0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet import BEVDet


@DETECTORS.register_module()
class BEVDet4D(BEVDet):
    r"""BEVDet4D paradigm for multi-camera 3D object detection.

    Please refer to the `paper <https://arxiv.org/abs/2203.17054>`_

    Args:
        pre_process (dict | None): Configuration dict of BEV pre-process net.
        align_after_view_transfromation (bool): Whether to align the BEV
            Feature after view transformation. By default, the BEV feature of
            the previous frame is aligned during the view transformation.
        num_adj (int): Number of adjacent frames.
        with_prev (bool): Whether to use the BEV feature of the previous
            frame (if False, it is set to all zero). By default, True.
    """

    def __init__(self,
                 pre_process=None,
                 align_after_view_transfromation=False,
                 num_adj=1,
                 with_prev=True,
                 **kwargs):
        super(BEVDet4D, self).__init__(**kwargs)
        self.pre_process = pre_process is not None
        if self.pre_process:
            self.pre_process_net = builder.build_backbone(pre_process)
        self.align_after_view_transfromation = align_after_view_transfromation
        self.num_frame = num_adj + 1
        self.with_prev = with_prev
        self.grid = None

    def gen_grid(self, input, sensor2keyegos, bda, bda_adj=None):
        """
        Args:
            input: (B, C, Dy, Dx) bev_feat
            sensor2keyegos: List[
                curr_sensor-->key_ego: (B, N_views, 4, 4)
                prev_sensor-->key_ego: (B, N_views, 4, 4)
            ]
            bda: (B, 3, 3)
            bda_adj: None
        Returns:
            grid: (B, Dy, Dx, 2)
        """
        B, C, H, W = input.shape
        v = sensor2keyegos[0].shape[0]      # N_views
        if self.grid is None:
            # generate grid
            xs = torch.linspace(0, W - 1, W, dtype=input.dtype,
                                device=input.device).view(1, W).expand(H, W)    # (Dy, Dx)
            ys = torch.linspace(0, H - 1, H, dtype=input.dtype,
                                device=input.device).view(H, 1).expand(H, W)    # (Dy, Dx)
            grid = torch.stack((xs, ys, torch.ones_like(xs)), -1)   # (Dy, Dx, 3)   3: (x, y, 1)
            self.grid = grid
        else:
            grid = self.grid
        # (Dy, Dx, 3) --> (1, Dy, Dx, 3) --> (B, Dy, Dx, 3) --> (B, Dy, Dx, 3, 1)   3: (grid_x, grid_y, 1)
        grid = grid.view(1, H, W, 3).expand(B, H, W, 3).view(B, H, W, 3, 1)

        curr_sensor2keyego = sensor2keyegos[0][:, 0:1, :, :]    # (B, 1, 4, 4)
        prev_sensor2keyego = sensor2keyegos[1][:, 0:1, :, :]    # (B, 1, 4, 4)

        # add bev data augmentation
        bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid)     # (B, 1, 4, 4)
        bda_[:, :, :3, :3] = bda.unsqueeze(1)
        bda_[:, :, 3, 3] = 1
        curr_sensor2keyego = bda_.matmul(curr_sensor2keyego)    # (B, 1, 4, 4)
        if bda_adj is not None:
            bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid)
            bda_[:, :, :3, :3] = bda_adj.unsqueeze(1)
            bda_[:, :, 3, 3] = 1
            prev_sensor2keyego = bda_.matmul(prev_sensor2keyego)    # (B, 1, 4, 4)

        # transformation from current ego frame to adjacent ego frame
        # key_ego --> prev_cam_front --> prev_ego
        keyego2adjego = curr_sensor2keyego.matmul(torch.inverse(prev_sensor2keyego))
        keyego2adjego = keyego2adjego.unsqueeze(dim=1)      # (B, 1, 1, 4, 4)
        # (B, 1, 1, 3, 3)
        keyego2adjego = keyego2adjego[..., [True, True, False, True], :][..., [True, True, False, True]]

        # x = grid_x * vx + x_min; y = grid_y * vy + y_min
        # feat2bev:
        # [[vx, 0,  x_min],
        #  [0,  vy, y_min],
        #  [0,  0,  1    ]]
        feat2bev = torch.zeros((3, 3), dtype=grid.dtype).to(grid)
        feat2bev[0, 0] = self.img_view_transformer.grid_interval[0]
        feat2bev[1, 1] = self.img_view_transformer.grid_interval[1]
        feat2bev[0, 2] = self.img_view_transformer.grid_lower_bound[0]
        feat2bev[1, 2] = self.img_view_transformer.grid_lower_bound[1]
        feat2bev[2, 2] = 1
        feat2bev = feat2bev.view(1, 3, 3)   # (1, 3, 3)
        # curr_feat_grid --> key ego --> prev_cam --> prev_ego --> prev_feat_grid
        tf = torch.inverse(feat2bev).matmul(keyego2adjego).matmul(feat2bev)     # (B, 1, 1, 3, 3)
        grid = tf.matmul(grid)      # (B, Dy, Dx, 3, 1)   3: (grid_x, grid_y, 1)
        normalize_factor = torch.tensor([W - 1.0, H - 1.0],
                                        dtype=input.dtype,
                                        device=input.device)    # (2, )
        # (B, Dy, Dx, 2)
        grid = grid[:, :, :, :2, 0] / normalize_factor.view(1, 1, 1, 2) * 2.0 - 1.0
        return grid
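
The feat2bev matrix built above maps feature-grid indices to metric BEV coordinates: x = grid_x * vx + x_min, y = grid_y * vy + y_min. A tiny standalone check using the 0.4 m interval and [-40, 40] range that this repo's occupancy grid uses elsewhere (the concrete numbers are just an example):

import torch

vx = vy = 0.4
x_min = y_min = -40.0
feat2bev = torch.tensor([[vx, 0.0, x_min],
                         [0.0, vy, y_min],
                         [0.0, 0.0, 1.0]])
idx = torch.tensor([100.0, 50.0, 1.0])      # (grid_x, grid_y, 1)
xy = feat2bev @ idx
print(xy[:2])                               # tensor([ 0., -20.])
print(torch.inverse(feat2bev) @ xy)         # recovers the grid indices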
    @force_fp32()
    def shift_feature(self, input, sensor2keyegos, bda, bda_adj=None):
        """
        Args:
            input: (B, C, Dy, Dx) bev_feat
            sensor2keyegos: List[
                curr_sensor-->key_ego: (B, N_views, 4, 4)
                prev_sensor-->key_ego: (B, N_views, 4, 4)
            ]
            bda: (B, 3, 3)
            bda_adj: None
        Returns:
            output: aligned bev feat (B, C, Dy, Dx).
        """
        grid = self.gen_grid(input, sensor2keyegos, bda, bda_adj=bda_adj)   # grid: (B, Dy, Dx, 2), in (-1, 1)
        output = F.grid_sample(input, grid.to(input.dtype), align_corners=True)     # (B, C, Dy, Dx)
        return output
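
shift_feature ultimately resamples the previous BEV map with F.grid_sample on a grid normalized to (-1, 1). A quick sanity check of the convention: with align_corners=True, an identity grid returns the input unchanged.

import torch
import torch.nn.functional as F

B, C, H, W = 1, 4, 8, 8
feat = torch.randn(B, C, H, W)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
grid = torch.stack((xs, ys), -1).unsqueeze(0)   # (1, H, W, 2), (x, y) order
out = F.grid_sample(feat, grid, align_corners=True)
print(torch.allclose(out, feat, atol=1e-6))     # True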
    def prepare_bev_feat(self, img, sensor2egos, ego2globals, intrin,
                         post_rot, post_tran, bda, mlp_input):
        """
        Args:
            imgs: (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins: (B, N_views, 3, 3)
            post_rots: (B, N_views, 3, 3)
            post_trans: (B, N_views, 3)
            bda_rot: (B, 3, 3)
            mlp_input:
        Returns:
            bev_feat: (B, C, Dy, Dx)
            depth: (B*N, D, fH, fW)
        """
        x, _ = self.image_encoder(img)      # x: (B, N, C, fH, fW)
        # bev_feat: (B, C * Dz(=1), Dy, Dx)
        # depth: (B * N, D, fH, fW)
        bev_feat, depth = self.img_view_transformer(
            [x, sensor2egos, ego2globals, intrin, post_rot, post_tran, bda, mlp_input])
        if self.pre_process:
            bev_feat = self.pre_process_net(bev_feat)[0]    # (B, C, Dy, Dx)
        return bev_feat, depth

    def extract_img_feat_sequential(self, inputs, feat_prev):
        """
        Args:
            inputs:
                curr_img: (1, N_views, 3, H, W)
                sensor2keyegos_curr: (N_prev, N_views, 4, 4)
                ego2globals_curr: (N_prev, N_views, 4, 4)
                intrins: (1, N_views, 3, 3)
                sensor2keyegos_prev: (N_prev, N_views, 4, 4)
                ego2globals_prev: (N_prev, N_views, 4, 4)
                post_rots: (1, N_views, 3, 3)
                post_trans: (1, N_views, 3, )
                bda_curr: (N_prev, 3, 3)
            feat_prev: (N_prev, C, Dy, Dx)
        Returns:
        """
        imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
        sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:]
        bev_feat_list = []
        mlp_input = self.img_view_transformer.get_mlp_input(
            sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
            intrins, post_rots, post_trans, bda[0:1, ...])
        inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
                       ego2globals_curr[0:1, ...], intrins, post_rots,
                       post_trans, bda[0:1, ...], mlp_input)
        # (1, C, Dx, Dy), (1*N, D, fH, fW)
        bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
        bev_feat_list.append(bev_feat)

        # align the feat_prev
        _, C, H, W = feat_prev.shape    # feat_prev: (N_prev, C, Dy, Dx)
        feat_prev = self.shift_feature(
            feat_prev,                  # (N_prev, C, Dy, Dx)
            [sensor2keyegos_curr,       # (N_prev, N_views, 4, 4)
             sensor2keyegos_prev],      # (N_prev, N_views, 4, 4)
            bda                         # (N_prev, 3, 3)
        )
        bev_feat_list.append(feat_prev.view(1, (self.num_frame - 1) * C, H, W))     # (1, N_prev*C, Dy, Dx)

        bev_feat = torch.cat(bev_feat_list, dim=1)      # (1, N_frames*C, Dy, Dx)
        x = self.bev_encoder(bev_feat)
        return [x], depth

    def prepare_inputs(self, img_inputs, stereo=False):
        """
        Args:
            img_inputs:
                imgs: (B, N, 3, H, W)        # N = 6 * (N_history + 1)
                sensor2egos: (B, N, 4, 4)
                ego2globals: (B, N, 4, 4)
                intrins: (B, N, 3, 3)
                post_rots: (B, N, 3, 3)
                post_trans: (B, N, 3)
                bda_rot: (B, 3, 3)
            stereo: bool
        Returns:
            imgs: List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]       len = N_frames
            sensor2keyegos: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
            ego2globals: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
            intrins: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
            post_rots: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
            post_trans: List[(B, N_views, 3), (B, N_views, 3), ...]
            bda: (B, 3, 3)
        """
        B, N, C, H, W = img_inputs[0].shape
        N = N // self.num_frame     # N_views = 6
        imgs = img_inputs[0].view(B, N, self.num_frame, C, H, W)    # (B, N_views, N_frames, C, H, W)
        imgs = torch.split(imgs, 1, 2)
        imgs = [t.squeeze(2) for t in imgs]     # List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]
        sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
            img_inputs[1:7]

        sensor2egos = sensor2egos.view(B, self.num_frame, N, 4, 4)
        ego2globals = ego2globals.view(B, self.num_frame, N, 4, 4)

        # calculate the transformation from sensor to key ego
        # key_ego --> global  (B, 1, 1, 4, 4)
        keyego2global = ego2globals[:, 0, 0, ...].unsqueeze(1).unsqueeze(1)
        # global --> key_ego  (B, 1, 1, 4, 4)
        global2keyego = torch.inverse(keyego2global.double())
        # sensor --> ego --> global --> key_ego
        sensor2keyegos = \
            global2keyego @ ego2globals.double() @ sensor2egos.double()    # (B, N_frames, N_views, 4, 4)
        sensor2keyegos = sensor2keyegos.float()

        # -------------------- for stereo --------------------------
        curr2adjsensor = None
        if stereo:
            # (B, N_frames, N_views, 4, 4), (B, N_frames, N_views, 4, 4)
            sensor2egos_cv, ego2globals_cv = sensor2egos, ego2globals
            sensor2egos_curr = \
                sensor2egos_cv[:, :self.temporal_frame, ...].double()    # (B, N_temporal=2, N_views, 4, 4)
            ego2globals_curr = \
                ego2globals_cv[:, :self.temporal_frame, ...].double()    # (B, N_temporal=2, N_views, 4, 4)
            sensor2egos_adj = \
                sensor2egos_cv[:, 1:self.temporal_frame + 1, ...].double()   # (B, N_temporal=2, N_views, 4, 4)
            ego2globals_adj = \
                ego2globals_cv[:, 1:self.temporal_frame + 1, ...].double()   # (B, N_temporal=2, N_views, 4, 4)
            # curr_sensor --> curr_ego --> global --> prev_ego --> prev_sensor
            curr2adjsensor = \
                torch.inverse(ego2globals_adj @ sensor2egos_adj) \
                @ ego2globals_curr @ sensor2egos_curr       # (B, N_temporal=2, N_views, 4, 4)
            curr2adjsensor = curr2adjsensor.float()     # (B, N_temporal=2, N_views, 4, 4)
            curr2adjsensor = torch.split(curr2adjsensor, 1, 1)
            curr2adjsensor = [p.squeeze(1) for p in curr2adjsensor]
            curr2adjsensor.extend([None for _ in range(self.extra_ref_frames)])
            # curr2adjsensor: List[(B, N_views, 4, 4), (B, N_views, 4, 4), None]
            assert len(curr2adjsensor) == self.num_frame
        # -------------------- for stereo --------------------------

        extra = [
            sensor2keyegos,     # (B, N_frames, N_views, 4, 4)
            ego2globals,        # (B, N_frames, N_views, 4, 4)
            intrins.view(B, self.num_frame, N, 3, 3),       # (B, N_frames, N_views, 3, 3)
            post_rots.view(B, self.num_frame, N, 3, 3),     # (B, N_frames, N_views, 3, 3)
            post_trans.view(B, self.num_frame, N, 3)        # (B, N_frames, N_views, 3)
        ]
        extra = [torch.split(t, 1, 1) for t in extra]
        extra = [[p.squeeze(1) for p in t] for t in extra]
        sensor2keyegos, ego2globals, intrins, post_rots, post_trans = extra
        return imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
            bda, curr2adjsensor
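
The per-frame slicing pattern used throughout prepare_inputs (view to expose the frame axis, split along it, squeeze it away) is easy to verify on the image tensor alone with toy sizes:

import torch

B, n_views, num_frame, C, H, W = 1, 6, 3, 3, 4, 4
stacked = torch.randn(B, n_views * num_frame, C, H, W)
imgs = stacked.view(B, n_views, num_frame, C, H, W)
imgs = [t.squeeze(2) for t in torch.split(imgs, 1, 2)]
print(len(imgs), imgs[0].shape)     # 3 torch.Size([1, 6, 3, 4, 4])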
    def extract_img_feat(self,
                         img_inputs,
                         img_metas,
                         pred_prev=False,
                         sequential=False,
                         **kwargs):
        """
        Args:
            img_inputs:
                imgs: (B, N, 3, H, W)        # N = 6 * (N_history + 1)
                sensor2egos: (B, N, 4, 4)
                ego2globals: (B, N, 4, 4)
                intrins: (B, N, 3, 3)
                post_rots: (B, N, 3, 3)
                post_trans: (B, N, 3)
                bda_rot: (B, 3, 3)
            img_metas:
            **kwargs:
        Returns:
            x: [(B, C', H', W'), ]
            depth: (B*N_views, D, fH, fW)
        """
        if sequential:
            return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])

        imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
            bda, _ = self.prepare_inputs(img_inputs)

        # Extract features of images.
        bev_feat_list = []
        depth_list = []
        key_frame = True    # back propagation for key frame only
        for img, sensor2keyego, ego2global, intrin, post_rot, post_tran in zip(
                imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans):
            if key_frame or self.with_prev:
                if self.align_after_view_transfromation:
                    sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
                mlp_input = self.img_view_transformer.get_mlp_input(
                    sensor2keyegos[0], ego2globals[0], intrin, post_rot, post_tran, bda)    # (B, N_views, 27)
                inputs_curr = (img, sensor2keyego, ego2global, intrin,
                               post_rot, post_tran, bda, mlp_input)
                if key_frame:
                    # bev_feat: (B, C, Dy, Dx)
                    # depth: (B*N_views, D, fH, fW)
                    bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
                else:
                    with torch.no_grad():
                        bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
            else:
                # https://github.com/HuangJunJie2017/BEVDet/issues/275
                bev_feat = torch.zeros_like(bev_feat_list[0])
                depth = None
            bev_feat_list.append(bev_feat)
            depth_list.append(depth)
            key_frame = False
        # bev_feat_list: List[(B, C, Dy, Dx), (B, C, Dy, Dx), ...]
        # depth_list: List[(B*N_views, D, fH, fW), (B*N_views, D, fH, fW), ...]

        if pred_prev:
            assert self.align_after_view_transfromation
            assert sensor2keyegos[0].shape[0] == 1      # batch_size = 1
            feat_prev = torch.cat(bev_feat_list[1:], dim=0)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            ego2globals_curr = \
                ego2globals[0].repeat(self.num_frame - 1, 1, 1, 1)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            sensor2keyegos_curr = \
                sensor2keyegos[0].repeat(self.num_frame - 1, 1, 1, 1)
            ego2globals_prev = torch.cat(ego2globals[1:], dim=0)            # (N_prev, N_views, 4, 4)
            sensor2keyegos_prev = torch.cat(sensor2keyegos[1:], dim=0)      # (N_prev, N_views, 4, 4)
            bda_curr = bda.repeat(self.num_frame - 1, 1, 1)                 # (N_prev, 3, 3)
            return feat_prev, [imgs[0],             # (1, N_views, 3, H, W)
                               sensor2keyegos_curr,  # (N_prev, N_views, 4, 4)
                               ego2globals_curr,     # (N_prev, N_views, 4, 4)
                               intrins[0],           # (1, N_views, 3, 3)
                               sensor2keyegos_prev,  # (N_prev, N_views, 4, 4)
                               ego2globals_prev,     # (N_prev, N_views, 4, 4)
                               post_rots[0],         # (1, N_views, 3, 3)
                               post_trans[0],        # (1, N_views, 3, )
                               bda_curr]             # (N_prev, 3, 3)

        if self.align_after_view_transfromation:
            for adj_id in range(1, self.num_frame):
                bev_feat_list[adj_id] = self.shift_feature(
                    bev_feat_list[adj_id],      # (B, C, Dy, Dx)
                    [sensor2keyegos[0],         # (B, N_views, 4, 4)
                     sensor2keyegos[adj_id]],   # (B, N_views, 4, 4)
                    bda                         # (B, 3, 3)
                )   # (B, C, Dy, Dx)
        bev_feat = torch.cat(bev_feat_list, dim=1)      # (B, N_frames*C, Dy, Dx)
        x = self.bev_encoder(bev_feat)
        return [x], depth_list[0]
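
The temporal fusion at the end of extract_img_feat is a plain channel concatenation of the per-frame BEV maps before the BEV encoder:

import torch

num_frame, B, C, Dy, Dx = 3, 2, 16, 8, 8
bev_feat_list = [torch.randn(B, C, Dy, Dx) for _ in range(num_frame)]
bev_feat = torch.cat(bev_feat_list, dim=1)
print(bev_feat.shape)   # torch.Size([2, 48, 8, 8]) -> (B, N_frames*C, Dy, Dx)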
projects/mmdet3d_plugin/models/detectors/bevdet_occ.py
0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool

from mmdet3d.models import DETECTORS
from mmdet3d.models.builder import build_head
from mmdet3d.core import bbox3d2result

from ...ops import TRTBEVPoolv2
from ...ops import nearest_assign
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D

# pool = ThreadPool(processes=4)    # create a thread pool

# for pano
grid_config_occ = {
    'x': [-40, 40, 0.4],
    'y': [-40, 40, 0.4],
    'z': [-1, 5.4, 6.4],
    'depth': [1.0, 45.0, 1.0],
}

# det
det_class_name = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]
# occ
occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade',
    'vegetation', 'free'
]
det_ind = [2, 3, 4, 5, 6, 7, 9, 10]
occ_ind = [5, 3, 0, 4, 6, 7, 2, 1]
detind2occind = {
    0: 4, 1: 10, 2: 9, 3: 3, 4: 5,
    5: 2, 6: 6, 7: 7, 8: 8, 9: 1,
}
occind2detind = {
    4: 0, 10: 1, 9: 2, 3: 3, 5: 4,
    2: 5, 6: 6, 7: 7, 8: 8, 1: 9,
}
occind2detind_cuda = [-1, -1, 5, 3, 0, 4, 6, 7, -1, 2, 1]

inst_occ = np.ones([200, 200, 16]) * 0

X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0)    # W, H, D, 3
# coords = coords.cpu().numpy()
st = [grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]]
sx = [grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4]
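
The module-level coords/st/sx triple maps voxel indices to metric ego coordinates as xyz = index * sx + st (a voxel-corner convention; add 0.5 * sx for voxel centers). A quick check on the 200x200x16 grid defined above, reusing those module-level names:

import torch

st_t = torch.tensor(st)
sx_t = torch.tensor(sx)
xyz = coords * sx_t + st_t              # (200, 200, 16, 3)
print(xyz[0, 0, 0], xyz[-1, -1, -1])    # corners of the occupancy volume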
@DETECTORS.register_module()
class BEVDetOCC(BEVDet):
    def __init__(self,
                 occ_head=None,
                 upsample=False,
                 **kwargs):
        super(BEVDetOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    # @torch.compile
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        losses = dict()
        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
        losses.update(loss_occ)
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        outs = self.occ_head(img_feats)
        # assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(
            outs,               # (B, Dx, Dy, Dz, n_cls)
            voxel_semantics,    # (B, Dx, Dy, Dz)
            mask_camera,        # (B, Dx, Dy, Dz)
        )
        return loss_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_list

    def simple_test_occ(self, img_feats, img_metas=None):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas:
        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        if not hasattr(self.occ_head, "get_occ_gpu"):
            occ_preds = self.occ_head.get_occ(outs, img_metas)      # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        else:
            occ_preds = self.occ_head.get_occ_gpu(outs, img_metas)  # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_preds

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        outs = self.occ_head(occ_bev_feature)
        return outs
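
The occ head's loss itself is defined elsewhere in this commit. As a hedged sketch of the common formulation it wraps - cross-entropy over voxels restricted to camera-visible ones via mask_camera, with 18 classes as in occ_class_names - purely for illustration:

import torch
import torch.nn.functional as F

B, Dx, Dy, Dz, n_cls = 1, 4, 4, 2, 18
logits = torch.randn(B, Dx, Dy, Dz, n_cls)              # occ head output
voxel_semantics = torch.randint(0, n_cls, (B, Dx, Dy, Dz))
mask_camera = torch.rand(B, Dx, Dy, Dz) > 0.3           # visible voxels

vis = mask_camera.reshape(-1)
loss = F.cross_entropy(logits.reshape(-1, n_cls)[vis],
                       voxel_semantics.reshape(-1)[vis])
print(loss)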
@DETECTORS.register_module()
class BEVDepthOCC(BEVDepth):
    def __init__(self,
                 occ_head=None,
                 upsample=False,
                 **kwargs):
        super(BEVDepthOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        losses = dict()
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses['loss_depth'] = loss_depth

        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
        losses.update(loss_occ)
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        outs = self.occ_head(img_feats)
        # assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(
            outs,               # (B, Dx, Dy, Dz, n_cls)
            voxel_semantics,    # (B, Dx, Dy, Dz)
            mask_camera,        # (B, Dx, Dy, Dz)
        )
        return loss_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_list

    def simple_test_occ(self, img_feats, img_metas=None):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas:
        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        # occ_preds = self.occ_head.get_occ(outs, img_metas)   # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        occ_preds = self.occ_head.get_occ_gpu(outs, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_preds

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        outs = self.occ_head(occ_bev_feature)
        return outs
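
Effect of the optional `upsample` branch used by the OCC detectors: a bilinear 2x upsampling of the BEV feature before the occupancy head (e.g. a 100x100 BEV map brought up to the 200x200 occupancy grid; the channel count here is arbitrary):

import torch
import torch.nn.functional as F

occ_bev_feature = torch.randn(1, 32, 100, 100)
up = F.interpolate(occ_bev_feature, scale_factor=2, mode='bilinear',
                   align_corners=True)
print(up.shape)     # torch.Size([1, 32, 200, 200])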
@DETECTORS.register_module()
class BEVDepthPano(BEVDepthOCC):
    def __init__(self,
                 aux_centerness_head=None,
                 **kwargs):
        super(BEVDepthPano, self).__init__(**kwargs)
        self.aux_centerness_head = None
        if aux_centerness_head:
            train_cfg = kwargs['train_cfg']
            test_cfg = kwargs['test_cfg']
            pts_train_cfg = train_cfg.pts if train_cfg else None
            aux_centerness_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            aux_centerness_head.update(test_cfg=pts_test_cfg)
            self.aux_centerness_head = build_head(aux_centerness_head)

        if 'inst_class_ids' in kwargs:
            self.inst_class_ids = kwargs['inst_class_ids']
        else:
            self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]

        X1, Y1, Z1 = 200, 200, 16
        coords_x = torch.arange(X1).float()
        coords_y = torch.arange(Y1).float()
        coords_z = torch.arange(Z1).float()
        self.coords = torch.stack(
            torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0)    # W, H, D, 3
        self.st = torch.tensor([grid_config_occ['x'][0],
                                grid_config_occ['y'][0],
                                grid_config_occ['z'][0]])
        self.sx = torch.tensor([grid_config_occ['x'][2],
                                grid_config_occ['y'][2],
                                0.4])
        self.is_to_d = False

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        losses = dict()
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses['loss_depth'] = loss_depth

        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
        losses.update(loss_occ)

        losses_aux_centerness = self.forward_aux_centerness_train(
            [occ_bev_feature], gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore)
        losses.update(losses_aux_centerness)
        return losses

    def forward_aux_centerness_train(self,
                                     pts_feats,
                                     gt_bboxes_3d,
                                     gt_labels_3d,
                                     img_metas,
                                     gt_bboxes_ignore=None):
        outs = self.aux_centerness_head(pts_feats)
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
        losses = self.aux_centerness_head.loss(*loss_inputs)
        return losses

    def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
        """Test function of point cloud branch."""
        # outs = self.aux_centerness_head(x)
        tx = self.aux_centerness_head.shared_conv(x[0])     # (B, C'=share_conv_channel, H, W)
        outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(tx)
        outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(tx)
        outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(tx)
        outs = ([{
            "reg": outs_inst_center_reg,
            "height": outs_inst_center_height,
            "heatmap": outs_inst_center_heatmap,
        }],)
        # bbox_list = self.aux_centerness_head.get_bboxes(
        #     outs, img_metas, rescale=rescale)
        # bbox_results = [
        #     bbox3d2result(bboxes, scores, labels)
        #     for bboxes, scores, labels in bbox_list
        # ]
        ins_cen_list = self.aux_centerness_head.get_centers(outs, img_metas, rescale=rescale)
        # return bbox_results, ins_cen_list
        return None, ins_cen_list

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        result_list = [dict() for _ in range(len(img_metas))]
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
        if w_pano == True:
            bbox_pts, ins_cen_list = self.simple_test_aux_centerness(
                [occ_bev_feature], img_metas, rescale=rescale, **kwargs)
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        for result_dict, occ_pred in zip(result_list, occ_list):
            result_dict['pred_occ'] = occ_pred

        w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True
        # 37.53 ms
        if w_panoproc == True:
            # for pano
            inst_xyz = ins_cen_list[0][0]
            if self.is_to_d == False:
                self.st = self.st.to(inst_xyz)
                self.sx = self.sx.to(inst_xyz)
                self.coords = self.coords.to(inst_xyz)
                self.is_to_d = True
            inst_xyz = ((inst_xyz - self.st) / self.sx).int()
            inst_cls = ins_cen_list[2][0].int()
            inst_num = 18
            # 37.62 ms
            # inst_occ = torch.tensor(occ_pred).to(inst_cls)
            inst_occ = occ_pred.clone().detach()
            # 37.61 ms
            if len(inst_cls) > 0:
                cls_sort, indices = inst_cls.sort()
                l2s = {}
                if len(inst_cls) == 1:
                    l2s[cls_sort[0].item()] = 0
                l2s[cls_sort[0].item()] = 0
                # tind_list = cls_sort[1:] - cls_sort[:-1] != 0
                # for tind in range(len(tind_list)):
                #     if tind_list[tind] == True:
                #         l2s[cls_sort[1 + tind].item()] = tind + 1
                tind_list = (cls_sort[1:] - cls_sort[:-1]) != 0
                if tind_list.__len__() > 0:
                    for tind in torch.range(0, len(tind_list) - 1)[tind_list]:
                        l2s[cls_sort[1 + int(tind.item())].item()] = int(tind.item()) + 1
                is_cuda = True
                # is_cuda = False
                if is_cuda == True:
                    inst_id_list = indices + inst_num
                    l2s_key = indices.new_tensor(
                        [detind2occind[k] for k in l2s.keys()]).to(torch.int)
                    inst_occ = nearest_assign(
                        occ_pred.to(torch.int),
                        l2s_key.to(torch.int),
                        indices.new_tensor(occind2detind_cuda).to(torch.int),
                        inst_cls.to(torch.int),
                        inst_xyz.to(torch.int),
                        inst_id_list.to(torch.int)
                    )
                else:
                    for cls_label_num_in_occ in self.inst_class_ids:
                        mask = occ_pred == cls_label_num_in_occ
                        if mask.sum() == 0:
                            continue
                        else:
                            cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
                            select_mask = inst_cls == cls_label_num_in_inst
                            if sum(select_mask) > 0:
                                indices = self.coords[mask]
                                inst_index_same_cls = inst_xyz[select_mask]
                                select_ind = ((indices[:, None, :] - inst_index_same_cls[None, :, :]) ** 2) \
                                    .sum(-1).argmin(axis=1).int()
                                inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
            result_list[0]['pano_inst'] = inst_occ      # .cpu().numpy()
        return result_list
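
The core of the CPU fallback in the pano post-processing above: every voxel of a "thing" class is assigned to the nearest predicted instance center of the same class via an argmin over squared distances. A toy 2D version of that step:

import torch

voxel_idx = torch.tensor([[0., 0.], [1., 9.], [8., 8.]])    # voxels of one class
centers = torch.tensor([[0., 1.], [9., 9.]])                # instance centers
d2 = ((voxel_idx[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
assign = d2.argmin(dim=1)
print(assign)   # tensor([0, 1, 1])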
@DETECTORS.register_module()
class BEVDepth4DOCC(BEVDepth4D):
    def __init__(self,
                 occ_head=None,
                 upsample=False,
                 **kwargs):
        super(BEVDepth4DOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        losses = dict()
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses['loss_depth'] = loss_depth

        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
        losses.update(loss_occ)
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        outs = self.occ_head(img_feats)
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(
            outs,               # (B, Dx, Dy, Dz, n_cls)
            voxel_semantics,    # (B, Dx, Dy, Dz)
            mask_camera,        # (B, Dx, Dy, Dz)
        )
        return loss_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_list = self.simple_test_occ(img_feats[0], img_metas)    # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_list

    def simple_test_occ(self, img_feats, img_metas=None):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas:
        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        # occ_preds = self.occ_head.get_occ(outs, img_metas)   # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        occ_preds = self.occ_head.get_occ_gpu(outs, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        return occ_preds

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        outs = self.occ_head(occ_bev_feature)
        return outs
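
simple_test_occ delegates to occ_head.get_occ_gpu, which is not part of this chunk of the commit. As a hedged sketch only, the final step of such a head typically reduces to an argmax over per-voxel class logits, kept on the same device:

import torch

logits = torch.randn(1, 200, 200, 16, 18)   # (B, Dx, Dy, Dz, n_cls)
occ_pred = logits.argmax(dim=-1)            # (B, Dx, Dy, Dz) class ids
occ_list = [occ_pred[i] for i in range(occ_pred.shape[0])]
print(occ_list[0].shape)                    # torch.Size([200, 200, 16])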
@DETECTORS.register_module()
class BEVDepth4DPano(BEVDepth4DOCC):
    def __init__(self,
                 aux_centerness_head=None,
                 **kwargs):
        super(BEVDepth4DPano, self).__init__(**kwargs)
        self.aux_centerness_head = None
        if aux_centerness_head:
            train_cfg = kwargs['train_cfg']
            test_cfg = kwargs['test_cfg']
            pts_train_cfg = train_cfg.pts if train_cfg else None
            aux_centerness_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            aux_centerness_head.update(test_cfg=pts_test_cfg)
            self.aux_centerness_head = build_head(aux_centerness_head)

        if 'inst_class_ids' in kwargs:
            self.inst_class_ids = kwargs['inst_class_ids']
        else:
            self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]

        X1, Y1, Z1 = 200, 200, 16
        coords_x = torch.arange(X1).float()
        coords_y = torch.arange(Y1).float()
        coords_z = torch.arange(Z1).float()
        self.coords = torch.stack(
            torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0)    # W, H, D, 3
        self.st = torch.tensor([grid_config_occ['x'][0],
                                grid_config_occ['y'][0],
                                grid_config_occ['z'][0]])
        self.sx = torch.tensor([grid_config_occ['x'][2],
                                grid_config_occ['y'][2],
                                0.4])
        self.is_to_d = False

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        losses = dict()
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses['loss_depth'] = loss_depth

        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
        losses.update(loss_occ)

        losses_aux_centerness = self.forward_aux_centerness_train(
            [img_feats[0]], gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore)
        losses.update(losses_aux_centerness)
        return losses

    def forward_aux_centerness_train(self,
                                     pts_feats,
                                     gt_bboxes_3d,
                                     gt_labels_3d,
                                     img_metas,
                                     gt_bboxes_ignore=None):
        outs = self.aux_centerness_head(pts_feats)
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
        losses = self.aux_centerness_head.loss(*loss_inputs)
        return losses

    def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
        """Test function of point cloud branch."""
        outs = self.aux_centerness_head(x)
        bbox_list = self.aux_centerness_head.get_bboxes(outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        ins_cen_list = self.aux_centerness_head.get_centers(outs, img_metas, rescale=rescale)
        return bbox_results, ins_cen_list

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx) / (B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        result_list = [dict() for _ in range(len(img_metas))]
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
        if w_pano == True:
            bbox_pts, ins_cen_list = self.simple_test_aux_centerness(
                [occ_bev_feature], img_metas, rescale=rescale, **kwargs)
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)     # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        for result_dict, occ_pred in zip(result_list, occ_list):
            result_dict['pred_occ'] = occ_pred

        w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True
        if w_panoproc == True:
            # for pano
            inst_xyz = ins_cen_list[0][0]
            if self.is_to_d == False:
                self.st = self.st.to(inst_xyz)
                self.sx = self.sx.to(inst_xyz)
                self.coords = self.coords.to(inst_xyz)
                self.is_to_d = True
            inst_xyz = ((inst_xyz - self.st) / self.sx).int()
            inst_cls = ins_cen_list[2][0].int()
            inst_num = 18
            # 37.62 ms
            # inst_occ = torch.tensor(occ_pred).to(inst_cls)
            inst_occ = occ_pred.clone().detach()
            # 37.61 ms
            if len(inst_cls) > 0:
                cls_sort, indices = inst_cls.sort()
                l2s = {}
                if len(inst_cls) == 1:
                    l2s[cls_sort[0].item()] = 0
                l2s[cls_sort[0].item()] = 0
                # tind_list = cls_sort[1:] - cls_sort[:-1] != 0
                # for tind in range(len(tind_list)):
                #     if tind_list[tind] == True:
                #         l2s[cls_sort[1 + tind].item()] = tind + 1
                tind_list = (cls_sort[1:] - cls_sort[:-1]) != 0
                if tind_list.__len__() > 0:
                    for tind in torch.range(0, len(tind_list) - 1)[tind_list]:
                        l2s[cls_sort[1 + int(tind.item())].item()] = int(tind.item()) + 1
                is_cuda = True
                # is_cuda = False
                if is_cuda == True:
                    inst_id_list = indices + inst_num
                    l2s_key = indices.new_tensor(
                        [detind2occind[k] for k in l2s.keys()]).to(torch.int)
                    inst_occ = nearest_assign(
                        occ_pred.to(torch.int),
                        l2s_key.to(torch.int),
                        indices.new_tensor(occind2detind_cuda).to(torch.int),
                        inst_cls.to(torch.int),
                        inst_xyz.to(torch.int),
                        inst_id_list.to(torch.int)
                    )
                else:
                    for cls_label_num_in_occ in self.inst_class_ids:
                        mask = occ_pred == cls_label_num_in_occ
                        if mask.sum() == 0:
                            continue
                        else:
                            cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
                            select_mask = inst_cls == cls_label_num_in_inst
                            if sum(select_mask) > 0:
                                indices = self.coords[mask]
                                inst_index_same_cls = inst_xyz[select_mask]
                                select_ind = ((indices[:, None, :] - inst_index_same_cls[None, :, :]) ** 2) \
                                    .sum(-1).argmin(axis=1).int()
                                inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
            result_list[0]['pano_inst'] = inst_occ      # .cpu().numpy()
        return result_list
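
The l2s bookkeeping above maps each detected "thing" class to the offset of its first instance in the class-sorted list. A compact equivalent of that loop, using the module-level detind2occind table to show how the keys translate into occupancy labels:

import torch

inst_cls = torch.tensor([7, 0, 7, 3])           # detector class per instance
cls_sort, indices = inst_cls.sort()
l2s = {}
for pos, c in enumerate(cls_sort.tolist()):
    l2s.setdefault(c, pos)                      # first position of each class
print(l2s)                                      # {0: 0, 3: 1, 7: 2}
print({detind2occind[k]: v for k, v in l2s.items()})    # keyed by occ label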
@DETECTORS.register_module()
class BEVStereo4DOCC(BEVStereo4D):
    def __init__(self,
                 occ_head=None,
                 upsample=False,
                 **kwargs):
        super(BEVStereo4DOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.

        Returns:
            dict: Losses of different branches.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        gt_depth = kwargs['gt_depth']   # (B, N_views, img_H, img_W)
        losses = dict()
        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses['loss_depth'] = loss_depth

        voxel_semantics = kwargs['voxel_semantics']     # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']             # (B, Dx, Dy, Dz)
        loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics,
                                          mask_camera)
        losses.update(loss_occ)
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)
        Returns:
        """
        outs = self.occ_head(img_feats)
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(
            outs,               # (B, Dx, Dy, Dz, n_cls)
            voxel_semantics,    # (B, Dx, Dy, Dz)
            mask_camera,        # (B, Dx, Dy, Dz)
        )
        return loss_occ
    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_list = self.simple_test_occ(img_feats[0], img_metas)    # List[(Dx, Dy, Dz), ...]
        return occ_list

    def simple_test_occ(self, img_feats, img_metas=None):
        """
        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas:
        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        # occ_preds = self.occ_head.get_occ(outs, img_metas)    # List[(Dx, Dy, Dz), ...]
        occ_preds = self.occ_head.get_occ_gpu(outs, img_metas)  # List[(Dx, Dy, Dz), ...]
        return occ_preds

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # pts_feats: None
        # depth: (B*N_views, D, fH, fW)
        img_feats, pts_feats, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
                                            mode='bilinear', align_corners=True)
        outs = self.occ_head(occ_bev_feature)
        return outs
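# --- Illustrative sketch (not part of the original file) --------------------
# The actual criterion lives in occ_head.loss (bev_occ_head.py). Purely to
# illustrate what a camera-mask-weighted voxel objective looks like, with
# shapes assumed from forward_occ_train above -- logits (B, Dx, Dy, Dz, C),
# labels and mask (B, Dx, Dy, Dz):
import torch
import torch.nn.functional as F

def masked_voxel_ce(logits, labels, mask):
    # Per-voxel cross-entropy, averaged only over camera-visible voxels.
    C = logits.shape[-1]
    loss = F.cross_entropy(logits.reshape(-1, C), labels.reshape(-1).long(),
                           reduction='none')
    mask = mask.reshape(-1).float()
    return (loss * mask).sum() / mask.sum().clamp(min=1)

logits = torch.randn(1, 200, 200, 16, 18)
labels = torch.randint(0, 18, (1, 200, 200, 16))
mask = torch.randint(0, 2, (1, 200, 200, 16))
print(masked_voxel_ce(logits, labels, mask))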
@DETECTORS.register_module()
class BEVDetOCCTRT(BEVDetOCC):
    def __init__(self,
                 wocc=True,
                 wdet3d=True,
                 uni_train=True,
                 **kwargs):
        super(BEVDetOCCTRT, self).__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        outs_ = []
        if outs_det3d is not None:
            for out in outs_det3d:
                for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
                    outs_.append(out[0][key])
        if outs_occ is not None:
            outs_.append(outs_occ)
        return outs_

    def result_deserialize(self, outs):
        outs_ = []
        keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
        for head_id in range(len(outs) // 6):
            outs_head = [dict()]
            for kid, key in enumerate(keys):
                outs_head[0][key] = outs[head_id * 6 + kid]
            outs_.append(outs_head)
        return outs_
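    # (Sketch) serialize flattens each detection head's dict into a fixed
    # tensor order so the exported engine can emit a flat list; deserialize
    # rebuilds the dicts. Illustrative round trip (tensors elided):
    #   flat = self.result_serialize(outs_det3d)    # 6 tensors per task head
    #   rebuilt = self.result_deserialize(flat)     # [[{'reg': ..., 'height': ..., ...}], ...]
    #   assert set(rebuilt[0][0]) == {'reg', 'height', 'dim', 'rot', 'vel', 'heatmap'}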
    def forward_part1(self, img):
        x = self.img_backbone(img)
        x = self.img_neck(x)
        x = self.img_view_transformer.depth_net(x[0])
        depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
        tran_feat = x[:, self.img_view_transformer.D:(
            self.img_view_transformer.D +
            self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)
        # depth = depth.reshape(-1)
        # tran_feat = tran_feat.flatten(0, 2)
        return tran_feat.flatten(0, 2), depth.reshape(-1)

    def forward_part2(self,
                      tran_feat,
                      depth,
                      ranks_depth,
                      ranks_feat,
                      ranks_bev,
                      interval_starts,
                      interval_lengths):
        tran_feat = tran_feat.reshape(6, 16, 44, 64)
        depth = depth.reshape(6, 16, 44, 44)
        x = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()))
        # -> [1, 64, 200, 200]
        return x.reshape(-1)

    def forward_part3(self, x):
        x = x.reshape(1, 200, 200, 64)
        x = x.permute(0, 3, 1, 2).contiguous()
        # return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
        bev_feature = self.img_bev_encoder_backbone(x)
        occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
        outs_occ = None
        if self.wocc:
            if self.uni_train:
                if self.upsample:
                    occ_bev_feature = F.interpolate(
                        occ_bev_feature, scale_factor=2, mode='bilinear',
                        align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)
        outs_det3d = None
        if self.wdet3d:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])
        outs = self.result_serialize(outs_det3d, outs_occ)
        return outs
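    # (Sketch) The part1/part2/part3 split mirrors forward_ori below, cut at
    # the custom BEV-pool op so each stage can be exported and profiled
    # separately. Hypothetical chained call, with the index tensors
    # precomputed once via get_bev_pool_input:
    #   tran_feat, depth = self.forward_part1(img)
    #   bev = self.forward_part2(tran_feat, depth, ranks_depth, ranks_feat,
    #                            ranks_bev, interval_starts, interval_lengths)
    #   outs = self.forward_part3(bev)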
    def forward_ori(self,
                    img,
                    ranks_depth,
                    ranks_feat,
                    ranks_bev,
                    interval_starts,
                    interval_lengths):
        x = self.img_backbone(img)
        x = self.img_neck(x)
        x = self.img_view_transformer.depth_net(x[0])
        depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
        tran_feat = x[:, self.img_view_transformer.D:(
            self.img_view_transformer.D +
            self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)
        x = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()))
        x = x.permute(0, 3, 1, 2).contiguous()
        # return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
        bev_feature = self.img_bev_encoder_backbone(x)
        occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
        outs_occ = None
        if self.wocc:
            if self.uni_train:
                if self.upsample:
                    occ_bev_feature = F.interpolate(
                        occ_bev_feature, scale_factor=2, mode='bilinear',
                        align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)
        outs_det3d = None
        if self.wdet3d:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])
        outs = self.result_serialize(outs_det3d, outs_occ)
        return outs

    def forward_with_argmax(self,
                            img,
                            ranks_depth,
                            ranks_feat,
                            ranks_bev,
                            interval_starts,
                            interval_lengths):
        outs = self.forward_ori(img, ranks_depth, ranks_feat, ranks_bev,
                                interval_starts, interval_lengths)
        pred_occ_label = outs[0].argmax(-1)
        return pred_occ_label

    def get_bev_pool_input(self, input):
        input = self.prepare_inputs(input)
        coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
        return self.img_view_transformer.voxel_pooling_prepare_v2(coor)
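# Note (assumption, for orientation): get_bev_pool_input is meant to run once
# offline. With fixed deployment calibration, the voxel-pooling index tensors
# (ranks_depth/ranks_feat/ranks_bev, interval_starts/interval_lengths, as laid
# out by voxel_pooling_prepare_v2) can be precomputed on example inputs and
# fed to the engine as constants alongside the images.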
@DETECTORS.register_module()
class BEVDepthOCCTRT(BEVDetOCC):
    def __init__(self,
                 wocc=True,
                 wdet3d=True,
                 uni_train=True,
                 **kwargs):
        super(BEVDepthOCCTRT, self).__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        outs_ = []
        if outs_det3d is not None:
            for out in outs_det3d:
                for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
                    outs_.append(out[0][key])
        if outs_occ is not None:
            outs_.append(outs_occ)
        return outs_

    def result_deserialize(self, outs):
        outs_ = []
        keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
        for head_id in range(len(outs) // 6):
            outs_head = [dict()]
            for kid, key in enumerate(keys):
                outs_head[0][key] = outs[head_id * 6 + kid]
            outs_.append(outs_head)
        return outs_

    def forward_ori(self,
                    img,
                    ranks_depth,
                    ranks_feat,
                    ranks_bev,
                    interval_starts,
                    interval_lengths,
                    mlp_input):
        x = self.img_backbone(img)
        x = self.img_neck(x)
        x = self.img_view_transformer.depth_net(x[0], mlp_input)
        depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
        tran_feat = x[:, self.img_view_transformer.D:(
            self.img_view_transformer.D +
            self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)
        x = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()))
        x = x.permute(0, 3, 1, 2).contiguous()
        # return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
        bev_feature = self.img_bev_encoder_backbone(x)
        occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
        outs_occ = None
        if self.wocc:
            if self.uni_train:
                if self.upsample:
                    occ_bev_feature = F.interpolate(
                        occ_bev_feature, scale_factor=2, mode='bilinear',
                        align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)
        outs_det3d = None
        if self.wdet3d:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])
        outs = self.result_serialize(outs_det3d, outs_occ)
        return outs

    def forward_with_argmax(self,
                            img,
                            ranks_depth,
                            ranks_feat,
                            ranks_bev,
                            interval_starts,
                            interval_lengths,
                            mlp_input):
        outs = self.forward_ori(img, ranks_depth, ranks_feat, ranks_bev,
                                interval_starts, interval_lengths, mlp_input)
        pred_occ_label = outs[0].argmax(-1)
        return pred_occ_label

    def get_bev_pool_input(self, input):
        input = self.prepare_inputs(input)
        coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
        mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
        # get_mlp_input(sensor2keyegos, ego2globals, intrins, post_rots,
        #               post_trans, bda) -> (B, N_views, 27)
        return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
@DETECTORS.register_module()
class BEVDepthPanoTRT(BEVDepthPano):
    def __init__(self,
                 wocc=True,
                 wdet3d=True,
                 uni_train=True,
                 **kwargs):
        super(BEVDepthPanoTRT, self).__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        outs_ = []
        if outs_det3d is not None:
            for out in outs_det3d:
                for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
                    outs_.append(out[0][key])
        if outs_occ is not None:
            outs_.append(outs_occ)
        return outs_

    def result_deserialize(self, outs):
        outs_ = []
        keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
        for head_id in range(len(outs) // 6):
            outs_head = [dict()]
            for kid, key in enumerate(keys):
                outs_head[0][key] = outs[head_id * 6 + kid]
            outs_.append(outs_head)
        return outs_

    def forward_part1(self, img, mlp_input):
        x = self.img_backbone(img)
        x = self.img_neck(x)
        x = self.img_view_transformer.depth_net(x[0], mlp_input)
        depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
        tran_feat = x[:, self.img_view_transformer.D:(
            self.img_view_transformer.D +
            self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)
        # depth = depth.reshape(-1)
        # tran_feat = tran_feat.flatten(0, 2)
        return tran_feat.flatten(0, 2), depth.reshape(-1)

    def forward_part2(self,
                      tran_feat,
                      depth,
                      ranks_depth,
                      ranks_feat,
                      ranks_bev,
                      interval_starts,
                      interval_lengths):
        tran_feat = tran_feat.reshape(6, 16, 44, 64)
        depth = depth.reshape(6, 16, 44, 44)
        x = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()))
        # -> [1, 64, 200, 200]
        return x.reshape(-1)

    def forward_part3(self, x):
        x = x.reshape(1, 200, 200, 64)
        x = x.permute(0, 3, 1, 2).contiguous()
        # return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
        bev_feature = self.img_bev_encoder_backbone(x)
        occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
        outs_occ = None
        if self.wocc:
            if self.uni_train:
                if self.upsample:
                    occ_bev_feature = F.interpolate(
                        occ_bev_feature, scale_factor=2, mode='bilinear',
                        align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)
        outs_det3d = None
        if self.wdet3d:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])
        outs = self.result_serialize(outs_det3d, outs_occ)
        # outs_inst_center = self.aux_centerness_head([occ_bev_feature])
        x = self.aux_centerness_head.shared_conv(occ_bev_feature)   # (B, C'=share_conv_channel, H, W)
        # run the different task heads separately
        outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
        outs.append(outs_inst_center_reg)
        outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
        outs.append(outs_inst_center_height)
        outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
        outs.append(outs_inst_center_heatmap)
        return outs  # assumed: forward_ori below returns the same list; the original listing ends the method without an explicit return

    def forward_ori(self,
                    img,
                    ranks_depth,
                    ranks_feat,
                    ranks_bev,
                    interval_starts,
                    interval_lengths,
                    mlp_input):
        x = self.img_backbone(img)
        x = self.img_neck(x)
        x = self.img_view_transformer.depth_net(x[0], mlp_input)
        depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
        tran_feat = x[:, self.img_view_transformer.D:(
            self.img_view_transformer.D +
            self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)
        x = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()))
        x = x.permute(0, 3, 1, 2).contiguous()
        # return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
        bev_feature = self.img_bev_encoder_backbone(x)
        occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
        outs_occ = None
        if self.wocc:
            if self.uni_train:
                if self.upsample:
                    occ_bev_feature = F.interpolate(
                        occ_bev_feature, scale_factor=2, mode='bilinear',
                        align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)
        outs_det3d = None
        if self.wdet3d:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])
        outs = self.result_serialize(outs_det3d, outs_occ)
        # outs_inst_center = self.aux_centerness_head([occ_bev_feature])
        x = self.aux_centerness_head.shared_conv(occ_bev_feature)   # (B, C'=share_conv_channel, H, W)
        # run the different task heads separately
        outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
        outs.append(outs_inst_center_reg)
        outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
        outs.append(outs_inst_center_height)
        outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
        outs.append(outs_inst_center_heatmap)
        return outs

    def forward_with_argmax(self,
                            img,
                            ranks_depth,
                            ranks_feat,
                            ranks_bev,
                            interval_starts,
                            interval_lengths,
                            mlp_input):
        outs = self.forward_ori(img, ranks_depth, ranks_feat, ranks_bev,
                                interval_starts, interval_lengths, mlp_input)
        pred_occ_label = outs[0].argmax(-1)
        return pred_occ_label, *outs[1:]

    def get_bev_pool_input(self, input):
        input = self.prepare_inputs(input)
        coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
        mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
        # get_mlp_input(sensor2keyegos, ego2globals, intrins, post_rots,
        #               post_trans, bda) -> (B, N_views, 27)
        return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
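# --- Illustrative sketch (not part of the original file) --------------------
# The three extra outputs appended above (reg, height, heatmap) are raw
# CenterPoint-style maps; the panoptic post-processing turns heatmap peaks
# plus the reg offsets into instance centers. A generic peak-decoding sketch
# (not this repo's get_centers; shapes and threshold are assumptions):
import torch

def decode_centers(heatmap, reg, score_thr=0.3):
    # heatmap: (1, n_cls, H, W) logits; reg: (1, 2, H, W) sub-cell offsets.
    prob = heatmap.sigmoid()
    score, cls = prob.max(dim=1)                    # (1, H, W)
    keep = score[0] > score_thr
    ys, xs = torch.nonzero(keep, as_tuple=True)     # integer BEV cell coords
    offs = reg[0, :, ys, xs]                        # (2, K) offsets
    centers = torch.stack([xs + offs[0], ys + offs[1]], dim=1)  # (K, 2)
    return centers, cls[0, ys, xs], score[0, ys, xs]

hm = torch.randn(1, 10, 200, 200)
rg = torch.rand(1, 2, 200, 200)
centers, labels, scores = decode_centers(hm, rg)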
projects/mmdet3d_plugin/models/detectors/bevstereo4d.py  0 → 100644
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdepth4d import BEVDepth4D
from mmdet.models.backbones.resnet import ResNet


@DETECTORS.register_module()
class BEVStereo4D(BEVDepth4D):
    def __init__(self, **kwargs):
        super(BEVStereo4D, self).__init__(**kwargs)
        self.extra_ref_frames = 1
        self.temporal_frame = self.num_frame
        self.num_frame += self.extra_ref_frames

    def extract_stereo_ref_feat(self, x):
        """
        Args:
            x: (B, N_views, 3, H, W)
        Returns:
            x: (B*N_views, C_stereo, fH_stereo, fW_stereo)
        """
        B, N, C, imH, imW = x.shape
        x = x.view(B * N, C, imH, imW)      # (B*N_views, 3, H, W)
        if isinstance(self.img_backbone, ResNet):
            if self.img_backbone.deep_stem:
                x = self.img_backbone.stem(x)
            else:
                x = self.img_backbone.conv1(x)
                x = self.img_backbone.norm1(x)
                x = self.img_backbone.relu(x)
            x = self.img_backbone.maxpool(x)
            for i, layer_name in enumerate(self.img_backbone.res_layers):
                res_layer = getattr(self.img_backbone, layer_name)
                x = res_layer(x)
                return x
        else:
            x = self.img_backbone.patch_embed(x)
            hw_shape = (self.img_backbone.patch_embed.DH,
                        self.img_backbone.patch_embed.DW)
            if self.img_backbone.use_abs_pos_embed:
                x = x + self.img_backbone.absolute_pos_embed
            x = self.img_backbone.drop_after_pos(x)
            for i, stage in enumerate(self.img_backbone.stages):
                x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
                out = out.view(-1, *out_hw_shape,
                               self.img_backbone.num_features[i])
                out = out.permute(0, 3, 1, 2).contiguous()
                return out

    def prepare_bev_feat(self, img, sensor2keyego, ego2global, intrin,
                         post_rot, post_tran, bda, mlp_input, feat_prev_iv,
                         k2s_sensor, extra_ref_frame):
        """
        Args:
            img: (B, N_views, 3, H, W)
            sensor2keyego: (B, N_views, 4, 4)
            ego2global: (B, N_views, 4, 4)
            intrin: (B, N_views, 3, 3)
            post_rot: (B, N_views, 3, 3)
            post_tran: (B, N_views, 3)
            bda: (B, 3, 3)
            mlp_input: (B, N_views, 27)
            feat_prev_iv: (B*N_views, C_stereo, fH_stereo, fW_stereo) or None
            k2s_sensor: (B, N_views, 4, 4) or None
            extra_ref_frame:
        Returns:
            bev_feat: (B, C, Dy, Dx)
            depth: (B*N, D, fH, fW)
            stereo_feat: (B*N_views, C_stereo, fH_stereo, fW_stereo)
        """
        if extra_ref_frame:
            stereo_feat = self.extract_stereo_ref_feat(img)     # (B*N_views, C_stereo, fH_stereo, fW_stereo)
            return None, None, stereo_feat
        # x: (B, N_views, C, fH, fW)
        # stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo)
        x, stereo_feat = self.image_encoder(img, stereo=True)
        # Information needed to build the cost volume.
        metas = dict(
            k2s_sensor=k2s_sensor,      # (B, N_views, 4, 4)
            intrins=intrin,             # (B, N_views, 3, 3)
            post_rots=post_rot,         # (B, N_views, 3, 3)
            post_trans=post_tran,       # (B, N_views, 3)
            frustum=self.img_view_transformer.cv_frustum.to(x),     # (D, fH_stereo, fW_stereo, 3)  3: (u, v, d)
            cv_downsample=4,
            downsample=self.img_view_transformer.downsample,
            grid_config=self.img_view_transformer.grid_config,
            cv_feat_list=[feat_prev_iv, stereo_feat])
        # bev_feat: (B, C * Dz(=1), Dy, Dx)
        # depth: (B * N, D, fH, fW)
        bev_feat, depth = self.img_view_transformer(
            [x, sensor2keyego, ego2global, intrin, post_rot, post_tran, bda,
             mlp_input], metas)
        if self.pre_process:
            bev_feat = self.pre_process_net(bev_feat)[0]    # (B, C, Dy, Dx)
        return bev_feat, depth, stereo_feat
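    # (Sketch) cv_feat_list pairs the previous frame's stereo feature with the
    # current one; the view transformer correlates them along the depth
    # hypotheses of cv_frustum -- a plane-sweep matching cost per depth plane,
    # roughly:
    #   cost[d] = sum_c ref_feat[c] * warped_prev_feat[d, c]   # (fH, fW)
    # which sharpens the per-pixel depth distribution used for lifting.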
    def extract_img_feat_sequential(self, inputs, feat_prev):
        """
        Args:
            inputs:
                curr_img: (1, N_views, 3, H, W)
                sensor2keyegos_curr: (N_prev, N_views, 4, 4)
                ego2globals_curr: (N_prev, N_views, 4, 4)
                intrins: (1, N_views, 3, 3)
                sensor2keyegos_prev: (N_prev, N_views, 4, 4)
                ego2globals_prev: (N_prev, N_views, 4, 4)
                post_rots: (1, N_views, 3, 3)
                post_trans: (1, N_views, 3)
                bda_curr: (N_prev, 3, 3)
                feat_prev_iv:
                curr2adjsensor: (1, N_views, 4, 4)
            feat_prev: (N_prev, C, Dy, Dx)
        Returns:
        """
        imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
        sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:9]
        feat_prev_iv, curr2adjsensor = inputs[9:]

        bev_feat_list = []
        mlp_input = self.img_view_transformer.get_mlp_input(
            sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
            intrins, post_rots, post_trans, bda[0:1, ...])
        inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
                       ego2globals_curr[0:1, ...], intrins, post_rots,
                       post_trans, bda[0:1, ...], mlp_input, feat_prev_iv,
                       curr2adjsensor, False)
        # (1, C, Dx, Dy), (1*N, D, fH, fW)
        bev_feat, depth, _ = self.prepare_bev_feat(*inputs_curr)
        bev_feat_list.append(bev_feat)

        # align the feat_prev
        _, C, H, W = feat_prev.shape
        # feat_prev: (N_prev, C, Dy, Dx)
        feat_prev = \
            self.shift_feature(feat_prev,               # (N_prev, C, Dy, Dx)
                               [sensor2keyegos_curr,    # (N_prev, N_views, 4, 4)
                                sensor2keyegos_prev],   # (N_prev, N_views, 4, 4)
                               bda)                     # (N_prev, 3, 3)
        bev_feat_list.append(
            feat_prev.view(1, (self.num_frame - 2) * C, H, W))  # (1, N_prev*C, Dy, Dx)

        bev_feat = torch.cat(bev_feat_list, dim=1)      # (1, N_frames*C, Dy, Dx)
        x = self.bev_encoder(bev_feat)
        return [x], depth

    def extract_img_feat(self,
                         img_inputs,
                         img_metas,
                         pred_prev=False,
                         sequential=False,
                         **kwargs):
        """
        Args:
            img_inputs:
                imgs: (B, N, 3, H, W)       # N = 6 * (N_history + 1)
                sensor2egos: (B, N, 4, 4)
                ego2globals: (B, N, 4, 4)
                intrins: (B, N, 3, 3)
                post_rots: (B, N, 3, 3)
                post_trans: (B, N, 3)
                bda_rot: (B, 3, 3)
            img_metas:
            **kwargs:
        Returns:
            x: [(B, C', H', W'), ]
            depth: (B*N_views, D, fH, fW)
        """
        if sequential:
            return self.extract_img_feat_sequential(img_inputs,
                                                    kwargs['feat_prev'])
        imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
            bda, curr2adjsensor = self.prepare_inputs(img_inputs, stereo=True)

        """Extract features of images."""
        bev_feat_list = []
        depth_key_frame = None
        feat_prev_iv = None
        for fid in range(self.num_frame - 1, -1, -1):
            img, sensor2keyego, ego2global, intrin, post_rot, post_tran = \
                imgs[fid], sensor2keyegos[fid], ego2globals[fid], \
                intrins[fid], post_rots[fid], post_trans[fid]
            key_frame = fid == 0
            extra_ref_frame = fid == self.num_frame - self.extra_ref_frames
            if key_frame or self.with_prev:
                if self.align_after_view_transfromation:
                    sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
                mlp_input = self.img_view_transformer.get_mlp_input(
                    sensor2keyegos[0], ego2globals[0], intrin, post_rot,
                    post_tran, bda)     # (B, N_views, 27)
                inputs_curr = (img, sensor2keyego, ego2global, intrin,
                               post_rot, post_tran, bda, mlp_input,
                               feat_prev_iv, curr2adjsensor[fid],
                               extra_ref_frame)
                if key_frame:
                    bev_feat, depth, feat_curr_iv = \
                        self.prepare_bev_feat(*inputs_curr)
                    depth_key_frame = depth
                else:
                    with torch.no_grad():
                        bev_feat, depth, feat_curr_iv = \
                            self.prepare_bev_feat(*inputs_curr)
                if not extra_ref_frame:
                    bev_feat_list.append(bev_feat)
                if not key_frame:
                    feat_prev_iv = feat_curr_iv

        if pred_prev:
            assert self.align_after_view_transfromation
            assert sensor2keyegos[0].shape[0] == 1      # batch_size = 1
            feat_prev = torch.cat(bev_feat_list[1:], dim=0)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            ego2globals_curr = \
                ego2globals[0].repeat(self.num_frame - 2, 1, 1, 1)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            sensor2keyegos_curr = \
                sensor2keyegos[0].repeat(self.num_frame - 2, 1, 1, 1)
            ego2globals_prev = torch.cat(ego2globals[1:-1], dim=0)          # (N_prev, N_views, 4, 4)
            sensor2keyegos_prev = torch.cat(sensor2keyegos[1:-1], dim=0)    # (N_prev, N_views, 4, 4)
            bda_curr = bda.repeat(self.num_frame - 2, 1, 1)                 # (N_prev, 3, 3)
            return feat_prev, \
                [imgs[0],                   # (1, N_views, 3, H, W)
                 sensor2keyegos_curr,       # (N_prev, N_views, 4, 4)
                 ego2globals_curr,          # (N_prev, N_views, 4, 4)
                 intrins[0],                # (1, N_views, 3, 3)
                 sensor2keyegos_prev,       # (N_prev, N_views, 4, 4)
                 ego2globals_prev,          # (N_prev, N_views, 4, 4)
                 post_rots[0],              # (1, N_views, 3, 3)
                 post_trans[0],             # (1, N_views, 3)
                 bda_curr,                  # (N_prev, 3, 3)
                 feat_prev_iv,
                 curr2adjsensor[0]]

        if not self.with_prev:
            bev_feat_key = bev_feat_list[0]
            if len(bev_feat_key.shape) == 4:
                b, c, h, w = bev_feat_key.shape
                bev_feat_list = \
                    [torch.zeros([b, c * (self.num_frame -
                                          self.extra_ref_frames - 1),
                                  h, w]).to(bev_feat_key), bev_feat_key]
            else:
                b, c, z, h, w = bev_feat_key.shape
                bev_feat_list = \
                    [torch.zeros([b, c * (self.num_frame -
                                          self.extra_ref_frames - 1),
                                  z, h, w]).to(bev_feat_key), bev_feat_key]
        if self.align_after_view_transfromation:
            for adj_id in range(self.num_frame - 2):
                bev_feat_list[adj_id] = self.shift_feature(
                    bev_feat_list[adj_id],                          # (B, C, Dy, Dx)
                    [sensor2keyegos[0],                             # (B, N_views, 4, 4)
                     sensor2keyegos[self.num_frame - 2 - adj_id]],  # (B, N_views, 4, 4)
                    bda)                                            # (B, 3, 3)
        bev_feat = torch.cat(bev_feat_list, dim=1)      # (B, C, Dy, Dx)
        x = self.bev_encoder(bev_feat)
        return [x], depth_key_frame
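# --- Illustrative trace (not part of the original file) ---------------------
# The temporal loop above walks frames from oldest to key frame; with the
# default extra_ref_frames = 1, the oldest frame only contributes a stereo
# reference feature. A tiny trace of the flags (assuming num_frame = 4 after
# the __init__ increment):
num_frame, extra_ref_frames = 4, 1
for fid in range(num_frame - 1, -1, -1):
    key_frame = fid == 0
    extra_ref_frame = fid == num_frame - extra_ref_frames
    print(fid, key_frame, extra_ref_frame)
# 3 False True    <- stereo-only reference frame
# 2 False False
# 1 False False
# 0 True  False   <- key frame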
projects/mmdet3d_plugin/models/losses/__init__.py  0 → 100644
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import CustomFocalLoss

__all__ = ['CrossEntropyLoss', 'CustomFocalLoss']
projects/mmdet3d_plugin/models/losses/cross_entropy_loss.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index
    # element-wise losses
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # average loss over non-ignored elements
    # pytorch's official cross_entropy averages loss over non-ignored elements
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(
        valid_mask & (labels < label_channels), as_tuple=False)

    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
                                               label_channels).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
        bin_label_weights *= valid_mask

    return bin_labels, bin_label_weights, valid_mask
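# --- Worked example (not part of the original file) --------------------------
# Quick hand check of _expand_onehot_labels semantics (values chosen by hand):
#   labels = torch.tensor([0, 2, -100])   # last entry is ignored
#   bin_labels, bin_weights, valid = _expand_onehot_labels(
#       labels, None, label_channels=3, ignore_index=-100)
#   bin_labels == [[1, 0, 0],
#                  [0, 0, 1],
#                  [0, 0, 0]]             # ignored row stays all-zero
#   valid / bin_weights zero out the ignored row entirely.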
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When the shape of pred is (N, 1), label will be expanded to
            one-hot format, and when the shape of pred is (N, ), label
            will not be expanded to one-hot format.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index

    if pred.dim() != label.dim():
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            # The inplace writing method will have a mismatched broadcast
            # shape error if the weight and valid_mask dimensions
            # are inconsistent such as (B, N, 1) and (B, N, C).
            weight = weight * valid_mask
        else:
            weight = valid_mask

    # average loss over non-ignored elements
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    # weighted element-wise losses
    weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)
    return loss


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask
            corresponding object. This will be used to select the mask of the
            class which the object belongs to when the mask prediction
            is not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other losses.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        >>>                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]


@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False.
            reduction (str, optional): Defaults to 'mean'.
                Options are "none", "mean" and "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): Whether the loss is only averaged over
                non-ignored targets. Default: False.
        """
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore
        if ((ignore_index is not None) and not self.avg_non_ignore
                and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value. Default: None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if ignore_index is None:
            ignore_index = self.ignore_index

        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)
        else:
            class_weight = None
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return loss_cls
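# --- Illustrative smoke test (not part of the original file) -----------------
# Minimal usage of the plain (softmax) path with an ignored label; 4 samples,
# 3 classes, values chosen by hand:
import torch

loss_fn = CrossEntropyLoss(ignore_index=255, avg_non_ignore=True)
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 255])
print(loss_fn(logits, labels))   # mean CE over the 3 non-ignored samples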
projects/mmdet3d_plugin/models/losses/focal_loss.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
import numpy as np


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        loss = loss * weight
    loss = loss.sum(-1).mean()
    # loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from `py_sigmoid_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    num_classes = pred.size(1)
    target = F.one_hot(target, num_classes=num_classes + 1)
    target = target[:, :num_classes]

    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        loss = loss * weight
    loss = loss.sum(-1).mean()
    # loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@LOSSES.register_module()
class CustomFocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=100.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                instead of softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 100.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(CustomFocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

        # Radial prior over the 200x200 BEV grid: voxels far from the ego
        # center get up to 2x the weight of central voxels.
        H, W = 200, 200
        xy, yx = torch.meshgrid([torch.arange(H) - H / 2,
                                 torch.arange(W) - W / 2])
        c = torch.stack([xy, yx], 2)
        c = torch.norm(c, 2, -1)
        c_max = c.max()
        self.c = (c / c_max + 1).cuda()

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                ignore_index=255,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        B, H, W, D = target.shape
        c = self.c[None, :, :, None].repeat(B, 1, 1, D).reshape(-1)
        visible_mask = (target != ignore_index).reshape(-1).nonzero().squeeze(-1)
        weight_mask = weight[None, :] * c[visible_mask, None]
        # visible_mask[:, None]
        num_classes = pred.size(1)
        pred = pred.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)[visible_mask]
        target = target.reshape(-1)[visible_mask]

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target.to(torch.long),
                weight_mask,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
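# --- Illustrative smoke test (not part of the original file) -----------------
# CustomFocalLoss pre-computes the radial weight map with .cuda(), so a GPU is
# required. Shapes below are assumed from forward(): pred (B, C, H=200, W=200, D)
# logits and target (B, H, W, D) labels, with a per-class weight vector:
import torch

if torch.cuda.is_available():
    loss_fn = CustomFocalLoss()
    pred = torch.randn(1, 17, 200, 200, 16, device='cuda')
    target = torch.randint(0, 17, (1, 200, 200, 16), device='cuda')
    cls_weight = torch.ones(17, device='cuda')   # broadcast against the radial prior
    print(loss_fn(pred, target, weight=cls_weight))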
projects/mmdet3d_plugin/models/losses/lovasz_softmax.py  0 → 100644
# -*- coding:utf-8 -*-
# author: Xinge
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

try:
    from itertools import ifilterfalse
except ImportError:     # py3k
    from itertools import filterfalse as ifilterfalse
from torch.cuda.amp import autocast


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:   # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
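# --- Worked example (not part of the original file) --------------------------
# Intuition: with errors sorted in decreasing order, lovasz_grad returns the
# per-position increase of the Jaccard loss, so dot(errors_sorted, grad) is
# the Lovasz extension of 1 - IoU. A hand-checkable case:
#   gt_sorted = torch.tensor([1, 0, 1])
#   intersection = 2 - cumsum([1, 0, 1]) = [1, 1, 0]
#   union        = 2 + cumsum([0, 1, 0]) = [2, 3, 3]
#   jaccard = 1 - intersection/union = [1/2, 2/3, 1]
#   lovasz_grad -> first value then differences: [0.5000, 0.1667, 0.3333]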
def
iou_binary
(
preds
,
labels
,
EMPTY
=
1.
,
ignore
=
None
,
per_image
=
True
):
"""
IoU for foreground class
binary: 1 foreground, 0 background
"""
if
not
per_image
:
preds
,
labels
=
(
preds
,),
(
labels
,)
ious
=
[]
for
pred
,
label
in
zip
(
preds
,
labels
):
intersection
=
((
label
==
1
)
&
(
pred
==
1
)).
sum
()
union
=
((
label
==
1
)
|
((
pred
==
1
)
&
(
label
!=
ignore
))).
sum
()
if
not
union
:
iou
=
EMPTY
else
:
iou
=
float
(
intersection
)
/
float
(
union
)
ious
.
append
(
iou
)
iou
=
mean
(
ious
)
# mean accross images if per_image
return
100
*
iou
def
iou
(
preds
,
labels
,
C
,
EMPTY
=
1.
,
ignore
=
None
,
per_image
=
False
):
"""
Array of IoU for each (non ignored) class
"""
if
not
per_image
:
preds
,
labels
=
(
preds
,),
(
labels
,)
ious
=
[]
for
pred
,
label
in
zip
(
preds
,
labels
):
iou
=
[]
for
i
in
range
(
C
):
if
i
!=
ignore
:
# The ignored label is sometimes among predicted classes (ENet - CityScapes)
intersection
=
((
label
==
i
)
&
(
pred
==
i
)).
sum
()
union
=
((
label
==
i
)
|
((
pred
==
i
)
&
(
label
!=
ignore
))).
sum
()
if
not
union
:
iou
.
append
(
EMPTY
)
else
:
iou
.
append
(
float
(
intersection
)
/
float
(
union
))
ious
.
append
(
iou
)
ious
=
[
mean
(
iou
)
for
iou
in
zip
(
*
ious
)]
# mean accross images if per_image
return
100
*
np
.
array
(
ious
)
# --------------------------- BINARY LOSSES ---------------------------
def
lovasz_hinge
(
logits
,
labels
,
per_image
=
True
,
ignore
=
None
):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if
per_image
:
loss
=
mean
(
lovasz_hinge_flat
(
*
flatten_binary_scores
(
log
.
unsqueeze
(
0
),
lab
.
unsqueeze
(
0
),
ignore
))
for
log
,
lab
in
zip
(
logits
,
labels
))
else
:
loss
=
lovasz_hinge_flat
(
*
flatten_binary_scores
(
logits
,
labels
,
ignore
))
return
loss
def
lovasz_hinge_flat
(
logits
,
labels
):
"""
Binary Lovasz hinge loss
logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
labels: [P] Tensor, binary ground truth labels (0 or 1)
ignore: label to ignore
"""
if
len
(
labels
)
==
0
:
# only void pixels, the gradients should be 0
return
logits
.
sum
()
*
0.
signs
=
2.
*
labels
.
float
()
-
1.
errors
=
(
1.
-
logits
*
Variable
(
signs
))
errors_sorted
,
perm
=
torch
.
sort
(
errors
,
dim
=
0
,
descending
=
True
)
perm
=
perm
.
data
gt_sorted
=
labels
[
perm
]
grad
=
lovasz_grad
(
gt_sorted
)
loss
=
torch
.
dot
(
F
.
relu
(
errors_sorted
),
Variable
(
grad
))
return
loss
def
flatten_binary_scores
(
scores
,
labels
,
ignore
=
None
):
"""
Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores
=
scores
.
view
(
-
1
)
labels
=
labels
.
view
(
-
1
)
if
ignore
is
None
:
return
scores
,
labels
valid
=
(
labels
!=
ignore
)
vscores
=
scores
[
valid
]
vlabels
=
labels
[
valid
]
return
vscores
,
vlabels
class
StableBCELoss
(
torch
.
nn
.
modules
.
Module
):
def
__init__
(
self
):
super
(
StableBCELoss
,
self
).
__init__
()
def
forward
(
self
,
input
,
target
):
neg_abs
=
-
input
.
abs
()
loss
=
input
.
clamp
(
min
=
0
)
-
input
*
target
+
(
1
+
neg_abs
.
exp
()).
log
()
return
loss
.
mean
()
def
binary_xloss
(
logits
,
labels
,
ignore
=
None
):
"""
Binary Cross entropy loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
ignore: void class id
"""
logits
,
labels
=
flatten_binary_scores
(
logits
,
labels
,
ignore
)
loss
=
StableBCELoss
()(
logits
,
Variable
(
labels
.
float
()))
return
loss
# --------------------------- MULTICLASS LOSSES ---------------------------
def
lovasz_softmax
(
probas
,
labels
,
classes
=
'present'
,
per_image
=
False
,
ignore
=
None
):
"""
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
if
per_image
:
loss
=
mean
(
lovasz_softmax_flat
(
*
flatten_probas
(
prob
.
unsqueeze
(
0
),
lab
.
unsqueeze
(
0
),
ignore
),
classes
=
classes
)
for
prob
,
lab
in
zip
(
probas
,
labels
))
else
:
with
autocast
(
False
):
loss
=
lovasz_softmax_flat
(
*
flatten_probas
(
probas
,
labels
,
ignore
),
classes
=
classes
)
return
loss
def
lovasz_softmax_flat
(
probas
,
labels
,
classes
=
'present'
):
"""
Multi-class Lovasz-Softmax loss
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if
probas
.
numel
()
==
0
:
# only void pixels, the gradients should be 0
return
probas
*
0.
C
=
probas
.
size
(
1
)
losses
=
[]
class_to_sum
=
list
(
range
(
C
))
if
classes
in
[
'all'
,
'present'
]
else
classes
for
c
in
class_to_sum
:
fg
=
(
labels
==
c
).
float
()
# foreground for class c
if
(
classes
is
'present'
and
fg
.
sum
()
==
0
):
continue
if
C
==
1
:
if
len
(
classes
)
>
1
:
raise
ValueError
(
'Sigmoid output possible only with 1 class'
)
class_pred
=
probas
[:,
0
]
else
:
class_pred
=
probas
[:,
c
]
errors
=
(
Variable
(
fg
)
-
class_pred
).
abs
()
errors_sorted
,
perm
=
torch
.
sort
(
errors
,
0
,
descending
=
True
)
perm
=
perm
.
data
fg_sorted
=
fg
[
perm
]
losses
.
append
(
torch
.
dot
(
errors_sorted
,
Variable
(
lovasz_grad
(
fg_sorted
))))
return
mean
(
losses
)
def
flatten_probas
(
probas
,
labels
,
ignore
=
None
):
"""
Flattens predictions in the batch
"""
if
probas
.
dim
()
==
2
:
if
ignore
is
not
None
:
valid
=
(
labels
!=
ignore
)
probas
=
probas
[
valid
]
labels
=
labels
[
valid
]
return
probas
,
labels
elif
probas
.
dim
()
==
3
:
# assumes output of a sigmoid layer
B
,
H
,
W
=
probas
.
size
()
probas
=
probas
.
view
(
B
,
1
,
H
,
W
)
elif
probas
.
dim
()
==
5
:
#3D segmentation
B
,
C
,
L
,
H
,
W
=
probas
.
size
()
probas
=
probas
.
contiguous
().
view
(
B
,
C
,
L
,
H
*
W
)
B
,
C
,
H
,
W
=
probas
.
size
()
probas
=
probas
.
permute
(
0
,
2
,
3
,
1
).
contiguous
().
view
(
-
1
,
C
)
# B * H * W, C = P, C
labels
=
labels
.
view
(
-
1
)
if
ignore
is
None
:
return
probas
,
labels
valid
=
(
labels
!=
ignore
)
vprobas
=
probas
[
valid
.
nonzero
().
squeeze
()]
vlabels
=
labels
[
valid
]
return
vprobas
,
vlabels
def
xloss
(
logits
,
labels
,
ignore
=
None
):
"""
Cross entropy loss
"""
return
F
.
cross_entropy
(
logits
,
Variable
(
labels
),
ignore_index
=
255
)
def
jaccard_loss
(
probas
,
labels
,
ignore
=
None
,
smooth
=
100
,
bk_class
=
None
):
"""
Something wrong with this loss
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
vprobas
,
vlabels
=
flatten_probas
(
probas
,
labels
,
ignore
)
true_1_hot
=
torch
.
eye
(
vprobas
.
shape
[
1
])[
vlabels
]
if
bk_class
:
one_hot_assignment
=
torch
.
ones_like
(
vlabels
)
one_hot_assignment
[
vlabels
==
bk_class
]
=
0
one_hot_assignment
=
one_hot_assignment
.
float
().
unsqueeze
(
1
)
true_1_hot
=
true_1_hot
*
one_hot_assignment
true_1_hot
=
true_1_hot
.
to
(
vprobas
.
device
)
intersection
=
torch
.
sum
(
vprobas
*
true_1_hot
)
cardinality
=
torch
.
sum
(
vprobas
+
true_1_hot
)
loss
=
(
intersection
+
smooth
/
(
cardinality
-
intersection
+
smooth
)).
mean
()
return
(
1
-
loss
)
*
smooth
def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
    """
    Multi-class Hinge Jaccard loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      ignore: void class labels
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)
    C = vprobas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        if c in vlabels:
            c_sample_ind = vlabels == c
            cprobas = vprobas[c_sample_ind, :]
            non_c_ind = np.array([a for a in class_to_sum if a != c])
            class_pred = cprobas[:, c]
            max_non_class_pred = torch.max(cprobas[:, non_c_ind], dim=1)[0]
            TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.) + smooth
            FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min=-hinge) + hinge)

            if (~c_sample_ind).sum() == 0:
                FP = 0
            else:
                nonc_probas = vprobas[~c_sample_ind, :]
                class_pred = nonc_probas[:, c]
                max_non_class_pred = torch.max(nonc_probas[:, non_c_ind], dim=1)[0]
                FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.)

            losses.append(1 - TP / (TP + FP + FN))

    if len(losses) == 0:
        return 0
    return mean(losses)
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
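`mean` averages any iterable in a single pass; with `ignore_nan=True` it relies on the `ifilterfalse` alias (`itertools.filterfalse`), which this file is assumed to import near its top. A quick sketch:

print(mean([1.0, 2.0, 3.0]))                                          # 2.0
print(mean((v for v in [1.0, float('nan'), 3.0]), ignore_nan=True))   # 2.0
print(mean([], empty=0))                                              # 0 -- fallback for empty input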
projects/mmdet3d_plugin/models/losses/semkitti_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from mmcv.runner import BaseModule, force_fp32
from torch.cuda.amp import autocast
semantic_kitti_class_frequencies = np.array(
    [
        5.41773033e09,
        1.57835390e07,
        1.25136000e05,
        1.18809000e05,
        6.46799000e05,
        8.21951000e05,
        2.62978000e05,
        2.83696000e05,
        2.04750000e05,
        6.16887030e07,
        4.50296100e06,
        4.48836500e07,
        2.26992300e06,
        5.68402180e07,
        1.57196520e07,
        1.58442623e08,
        2.06162300e06,
        3.69705220e07,
        1.15198800e06,
        3.34146000e05,
    ]
)

kitti_class_names = [
    "empty",
    "car",
    "bicycle",
    "motorcycle",
    "truck",
    "other-vehicle",
    "person",
    "bicyclist",
    "motorcyclist",
    "road",
    "parking",
    "sidewalk",
    "other-ground",
    "building",
    "fence",
    "vegetation",
    "trunk",
    "terrain",
    "pole",
    "traffic-sign",
]
def inverse_sigmoid(x, sign='A'):
    # `sign` is only a debug tag passed by the callers below; it does not affect the math.
    x = x.to(torch.float32)
    # Clamp the scalar probability into [1e-5, 1 - 1e-5) so the logit stays finite.
    while x >= 1 - 1e-5:
        x = x - 1e-5
    while x < 1e-5:
        x = x + 1e-5
    return -torch.log((1 / x) - 1)    # logit(x) = log(x / (1 - x))
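The point of `inverse_sigmoid` in the losses below: feeding `inverse_sigmoid(p)` into `binary_cross_entropy_with_logits` with an all-ones target reduces to `-log(p)` while staying numerically safe under autocast. A small check:

import torch
import torch.nn.functional as F

p = torch.tensor(0.8)
logit = inverse_sigmoid(p)
print(torch.sigmoid(logit))    # tensor(0.8000) -- round trip
print(F.binary_cross_entropy_with_logits(logit, torch.ones_like(p)))   # tensor(0.2231) == -log(0.8)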
def KL_sep(p, target):
    """
    KL divergence on nonzero classes
    """
    nonzeros = target != 0
    nonzero_p = p[nonzeros]
    kl_term = F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum")
    return kl_term
def geo_scal_loss(pred, ssc_target, ignore_index=255, non_empty_idx=0):
    # Get softmax probabilities
    pred = F.softmax(pred, dim=1)

    # Compute empty and nonempty probabilities
    empty_probs = pred[:, non_empty_idx]
    nonempty_probs = 1 - empty_probs

    # Remove unknown voxels
    mask = ssc_target != ignore_index
    nonempty_target = ssc_target != non_empty_idx
    nonempty_target = nonempty_target[mask].float()
    nonempty_probs = nonempty_probs[mask]
    empty_probs = empty_probs[mask]

    eps = 1e-5
    intersection = (nonempty_target * nonempty_probs).sum()
    precision = intersection / (nonempty_probs.sum() + eps)
    recall = intersection / (nonempty_target.sum() + eps)
    spec = ((1 - nonempty_target) * (empty_probs)).sum() / ((1 - nonempty_target).sum() + eps)
    with autocast(False):
        return (
            F.binary_cross_entropy_with_logits(
                inverse_sigmoid(precision, 'A'), torch.ones_like(precision)) +
            F.binary_cross_entropy_with_logits(
                inverse_sigmoid(recall, 'B'), torch.ones_like(recall)) +
            F.binary_cross_entropy_with_logits(
                inverse_sigmoid(spec, 'C'), torch.ones_like(spec))
        )
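Usage sketch for `geo_scal_loss`; the class count, grid size, and void label below are illustrative assumptions, with class 0 treated as empty:

import torch

pred = torch.randn(2, 18, 32, 32, 4)                    # (B, C, Dx, Dy, Dz) logits
target = torch.randint(0, 18, (2, 32, 32, 4)).float()
target[0, :2] = 255                                     # mark some voxels as unknown
print(geo_scal_loss(pred, target, ignore_index=255, non_empty_idx=0))   # scalar tensor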
def sem_scal_loss(pred_, ssc_target, ignore_index=255):
    # Get softmax probabilities
    with autocast(False):
        pred = F.softmax(pred_, dim=1)      # (B, n_class, Dx, Dy, Dz)
        loss = 0
        count = 0
        mask = ssc_target != ignore_index
        n_classes = pred.shape[1]
        begin = 0
        for i in range(begin, n_classes - 1):
            # Get probability of class i
            p = pred[:, i]      # (B, Dx, Dy, Dz)

            # Remove unknown voxels
            target_ori = ssc_target     # (B, Dx, Dy, Dz)
            p = p[mask]
            target = ssc_target[mask]

            completion_target = torch.ones_like(target)
            completion_target[target != i] = 0
            completion_target_ori = torch.ones_like(target_ori).float()
            completion_target_ori[target_ori != i] = 0
            if torch.sum(completion_target) > 0:
                count += 1.0
                nominator = torch.sum(p * completion_target)
                loss_class = 0
                if torch.sum(p) > 0:
                    precision = nominator / (torch.sum(p) + 1e-5)
                    loss_precision = F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(precision, 'D'),
                        torch.ones_like(precision))
                    loss_class += loss_precision
                if torch.sum(completion_target) > 0:
                    recall = nominator / (torch.sum(completion_target) + 1e-5)
                    # loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
                    loss_recall = F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(recall, 'E'),
                        torch.ones_like(recall))
                    loss_class += loss_recall
                if torch.sum(1 - completion_target) > 0:
                    specificity = torch.sum((1 - p) * (1 - completion_target)) / (
                        torch.sum(1 - completion_target) + 1e-5)
                    loss_specificity = F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(specificity, 'F'),
                        torch.ones_like(specificity))
                    loss_class += loss_specificity
                loss += loss_class
                # print(i, loss_class, loss_recall, loss_specificity)
        l = loss / count
        if torch.isnan(l):
            from IPython import embed
            embed()
            exit()
        return l
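`sem_scal_loss` is the per-class counterpart: it accumulates precision/recall/specificity BCE terms for every class that appears in the target. A usage sketch with assumed sizes (random targets make every class present, so `count` stays nonzero):

import torch

pred = torch.randn(2, 18, 32, 32, 4)                    # (B, n_class, Dx, Dy, Dz) logits
target = torch.randint(0, 18, (2, 32, 32, 4)).float()
target[0, 0, 0, 0] = 255                                # void voxel, excluded by the mask
print(sem_scal_loss(pred, target, ignore_index=255))    # scalar tensor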
def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
    """
    :param pred: the predicted tensor, must be [BS, C, ...]
    """
    criterion = nn.CrossEntropyLoss(
        weight=class_weights, ignore_index=ignore_index, reduction="mean")
    # from IPython import embed
    # embed()
    # exit()
    with autocast(False):
        loss = criterion(pred, target.long())
    return loss
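A common way to turn `semantic_kitti_class_frequencies` above into the `class_weights` argument is inverse log-frequency weighting; whether this repo's configs do exactly that is an assumption, so treat this as a sketch:

import numpy as np
import torch

# Inverse log-frequency weighting (assumed convention, not read from a config):
class_weights = torch.from_numpy(
    1.0 / np.log(semantic_kitti_class_frequencies + 0.001)).float()

pred = torch.randn(2, 20, 32, 32, 4)       # 20 SemanticKITTI classes
target = torch.randint(0, 20, (2, 32, 32, 4))
print(CE_ssc_loss(pred, target, class_weights, ignore_index=255))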
def vel_loss(pred, gt):
    with autocast(False):
        return F.l1_loss(pred, gt)
projects/mmdet3d_plugin/models/model_utils/__init__.py
from .depthnet import DepthNet

__all__ = ['DepthNet']
projects/mmdet3d_plugin/models/model_utils/depthnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.backbones.resnet import BasicBlock
from mmcv.cnn import build_conv_layer
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.checkpoint import checkpoint
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class ASPP(nn.Module):
    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super(ASPP, self).__init__()
        dilations = [1, 6, 12, 18]
        self.aspp1 = _ASPPModule(inplanes, mid_channels, 1, padding=0,
                                 dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, mid_channels, 3, padding=dilations[1],
                                 dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, mid_channels, 3, padding=dilations[2],
                                 dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, mid_channels, 3, padding=dilations[3],
                                 dilation=dilations[3], BatchNorm=BatchNorm)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(int(mid_channels * 5), inplanes, 1, bias=False)
        self.bn1 = BatchNorm(inplanes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """
        Args:
            x: (B*N, C, fH, fW)
        Returns:
            x: (B*N, C, fH, fW)
        """
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)    # (B*N, 5*C', fH, fW)

        x = self.conv1(x)    # (B*N, C, fH, fW)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
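Shape sanity check for `ASPP`: the four dilated branches plus global pooling are fused back to `inplanes` channels, so input and output shapes match. The channel sizes below are arbitrary assumptions:

import torch

aspp = ASPP(inplanes=512, mid_channels=96)
feat = torch.randn(6, 512, 16, 44)         # (B*N_views, C, fH, fW)
print(aspp(feat).shape)                    # torch.Size([6, 512, 16, 44])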
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.ReLU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x: (B*N_views, 27)
        Returns:
            x: (B*N_views, C)
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class SELayer(nn.Module):
    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        """
        Args:
            x: (B*N_views, C_mid, fH, fW)
            x_se: (B*N_views, C_mid, 1, 1)
        Returns:
            x: (B*N_views, C_mid, fH, fW)
        """
        x_se = self.conv_reduce(x_se)    # (B*N_views, C_mid, 1, 1)
        x_se = self.act1(x_se)           # (B*N_views, C_mid, 1, 1)
        x_se = self.conv_expand(x_se)    # (B*N_views, C_mid, 1, 1)
        return x * self.gate(x_se)       # (B*N_views, C_mid, fH, fW)
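`Mlp` and `SELayer` together form the camera-aware modulation used by `DepthNet` below: the 27-dim camera-parameter vector is lifted to `C_mid` channels and gates the image feature per channel. A sketch with assumed sizes:

import torch

mid = 256
mlp = Mlp(in_features=27, hidden_features=mid, out_features=mid)
se = SELayer(channels=mid)

feat = torch.randn(6, mid, 16, 44)         # (B*N_views, C_mid, fH, fW)
cam_params = torch.randn(6, 27)            # flattened intrinsics/extrinsics
gate = mlp(cam_params)[..., None, None]    # (B*N_views, C_mid, 1, 1)
print(se(feat, gate).shape)                # torch.Size([6, 256, 16, 44])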
class DepthNet(nn.Module):
    def __init__(self,
                 in_channels,
                 mid_channels,
                 context_channels,
                 depth_channels,
                 use_dcn=True,
                 use_aspp=True,
                 with_cp=False,
                 stereo=False,
                 bias=0.0,
                 aspp_mid_channels=-1):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        # generates the context feature
        self.context_conv = nn.Conv2d(
            mid_channels, context_channels, kernel_size=1, stride=1, padding=0)
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(in_features=27, hidden_features=mid_channels,
                             out_features=mid_channels)
        self.depth_se = SELayer(channels=mid_channels)      # NOTE: add camera-aware
        self.context_mlp = Mlp(in_features=27, hidden_features=mid_channels,
                               out_features=mid_channels)
        self.context_se = SELayer(channels=mid_channels)    # NOTE: add camera-aware

        depth_conv_input_channels = mid_channels
        downsample = None

        if stereo:
            depth_conv_input_channels += depth_channels
            downsample = nn.Conv2d(depth_conv_input_channels, mid_channels, 1, 1, 0)
            cost_volumn_net = []
            for stage in range(int(2)):
                cost_volumn_net.extend([
                    nn.Conv2d(depth_channels, depth_channels,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(depth_channels)])
            self.cost_volumn_net = nn.Sequential(*cost_volumn_net)
        self.bias = bias

        # 3 residual blocks
        depth_conv_list = [
            BasicBlock(depth_conv_input_channels, mid_channels, downsample=downsample),
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels)]
        if use_aspp:
            if aspp_mid_channels < 0:
                aspp_mid_channels = mid_channels
            depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels))
        if use_dcn:
            depth_conv_list.append(
                build_conv_layer(
                    cfg=dict(
                        type='DCN',
                        in_channels=mid_channels,
                        out_channels=mid_channels,
                        kernel_size=3,
                        padding=1,
                        groups=4,
                        im2col_step=128,
                    )))
        depth_conv_list.append(
            nn.Conv2d(mid_channels, depth_channels, kernel_size=1, stride=1, padding=0))
        self.depth_conv = nn.Sequential(*depth_conv_list)
        self.with_cp = with_cp
        self.depth_channels = depth_channels
    # ----------------------------------------- for building the cost volume ----------------------------------
    def gen_grid(self, metas, B, N, D, H, W, hi, wi):
        """
        Args:
            metas: dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
            B: batch size
            N: N_views
            D: D
            H: fH_stereo
            W: fW_stereo
            hi: H_img
            wi: W_img
        Returns:
            grid: (B*N_views, D*fH_stereo, fW_stereo, 2)
        """
        frustum = metas['frustum']    # (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)

        # undo the image-space augmentation:
        points = frustum - metas['post_trans'].view(B, N, 1, 1, 1, 3)
        points = torch.inverse(metas['post_rots']).view(B, N, 1, 1, 1, 3, 3) \
            .matmul(points.unsqueeze(-1))
        # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)

        # (u, v, d) --> (du, dv, d)
        # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
        points = torch.cat(
            (points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)

        # cur_pixel --> curr_camera --> prev_camera
        rots = metas['k2s_sensor'][:, :, :3, :3].contiguous()
        trans = metas['k2s_sensor'][:, :, :3, 3].contiguous()
        combine = rots.matmul(torch.inverse(metas['intrins']))
        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points)
        points += trans.view(B, N, 1, 1, 1, 3, 1)
        # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
        neg_mask = points[..., 2, 0] < 1e-3

        # prev_camera --> prev_pixel
        points = metas['intrins'].view(B, N, 1, 1, 1, 3, 3).matmul(points)
        # (du, dv, d) --> (u, v)   (B, N_views, D, fH_stereo, fW_stereo, 2, 1)
        points = points[..., :2, :] / points[..., 2:3, :]

        # re-apply the image-space augmentation
        points = metas['post_rots'][..., :2, :2].view(B, N, 1, 1, 1, 2, 2).matmul(
            points).squeeze(-1)
        points += metas['post_trans'][..., :2].view(B, N, 1, 1, 1, 2)
        # (B, N_views, D, fH_stereo, fW_stereo, 2)

        px = points[..., 0] / (wi - 1.0) * 2.0 - 1.0
        py = points[..., 1] / (hi - 1.0) * 2.0 - 1.0
        px[neg_mask] = -2
        py[neg_mask] = -2
        grid = torch.stack([px, py], dim=-1)    # (B, N_views, D, fH_stereo, fW_stereo, 2)
        grid = grid.view(B * N, D * H, W, 2)    # (B*N_views, D*fH_stereo, fW_stereo, 2)
        return grid
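    # Illustrative aside (not part of the original file): F.grid_sample expects
    # coordinates in [-1, 1], which the px/py normalization above produces.
    # For example, with an assumed image width wi = 704:
    #     u = torch.tensor([0.0, 351.5, 703.0])   # left edge, center, right edge
    #     px = u / (wi - 1.0) * 2.0 - 1.0          # -> tensor([-1., 0., 1.])
    # Points projected behind the previous camera are pushed to -2 so that they
    # sample zeros under padding_mode='zeros'.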
    def calculate_cost_volumn(self, metas):
        """
        Args:
            metas: dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
        Returns:
            cost_volumn: (B*N_views, D, fH_stereo, fW_stereo)
        """
        prev, curr = metas['cv_feat_list']    # (B*N_views, C_stereo, fH_stereo, fW_stereo)
        group_size = 4
        _, c, hf, wf = curr.shape
        hi, wi = hf * 4, wf * 4    # H_img, W_img
        B, N, _ = metas['post_trans'].shape
        D, H, W, _ = metas['frustum'].shape
        grid = self.gen_grid(metas, B, N, D, H, W, hi, wi).to(curr.dtype)
        # (B*N_views, D*fH_stereo, fW_stereo, 2)

        prev = prev.view(B * N, -1, H, W)    # (B*N_views, C_stereo, fH_stereo, fW_stereo)
        curr = curr.view(B * N, -1, H, W)    # (B*N_views, C_stereo, fH_stereo, fW_stereo)
        cost_volumn = 0
        # process group-wise to save memory
        for fid in range(curr.shape[1] // group_size):
            # (B*N_views, group_size, fH_stereo, fW_stereo)
            prev_curr = prev[:, fid * group_size:(fid + 1) * group_size, ...]
            wrap_prev = F.grid_sample(prev_curr, grid, align_corners=True,
                                      padding_mode='zeros')
            # (B*N_views, group_size, D*fH_stereo, fW_stereo)
            curr_tmp = curr[:, fid * group_size:(fid + 1) * group_size, ...]
            # (B*N_views, group_size, 1, fH_stereo, fW_stereo) - (B*N_views, group_size, D, fH_stereo, fW_stereo)
            # --> (B*N_views, group_size, D, fH_stereo, fW_stereo)
            # https://github.com/HuangJunJie2017/BEVDet/issues/278
            cost_volumn_tmp = curr_tmp.unsqueeze(2) - \
                wrap_prev.view(B * N, -1, D, H, W)
            cost_volumn_tmp = cost_volumn_tmp.abs().sum(dim=1)
            # (B*N_views, D, fH_stereo, fW_stereo)
            cost_volumn += cost_volumn_tmp

        if not self.bias == 0:
            invalid = wrap_prev[:, 0, ...].view(B * N, D, H, W) == 0
            cost_volumn[invalid] = cost_volumn[invalid] + self.bias

        # matching cost --> prob
        cost_volumn = -cost_volumn
        cost_volumn = cost_volumn.softmax(dim=1)
        return cost_volumn
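    # Illustrative aside (not part of the original file): the closing two lines
    # turn the aggregated matching cost into a depth distribution by taking the
    # softmax over the negated cost along the depth axis, so the hypothesis with
    # the lowest cost gets the highest probability. With made-up numbers:
    #     cost = torch.tensor([3.0, 0.5, 2.0])   # cost for three depth hypotheses
    #     prob = (-cost).softmax(dim=0)          # -> tensor([0.0629, 0.7662, 0.1709])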
    # ----------------------------------------- for building the cost volume --------------------------------------

    def forward(self, x, mlp_input, stereo_metas=None):
        """
        Args:
            x: (B*N_views, C, fH, fW)
            mlp_input: (B, N_views, 27)
            stereo_metas: None or dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
        Returns:
            x: (B*N_views, D+C_context, fH, fW)
        """
        mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))    # (B*N_views, 27)
        x = self.reduce_conv(x)    # (B*N_views, C_mid, fH, fW)

        # (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)    # (B*N_views, C_mid, fH, fW)
        context = self.context_conv(context)        # (B*N_views, C_context, fH, fW)

        # (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth = self.depth_se(x, depth_se)          # (B*N_views, C_mid, fH, fW)

        if stereo_metas is not None:
            if stereo_metas['cv_feat_list'][0] is None:
                BN, _, H, W = x.shape
                scale_factor = float(stereo_metas['downsample']) / \
                    stereo_metas['cv_downsample']
                cost_volumn = \
                    torch.zeros((BN, self.depth_channels,
                                 int(H * scale_factor),
                                 int(W * scale_factor))).to(x)
            else:
                with torch.no_grad():
                    # https://github.com/HuangJunJie2017/BEVDet/issues/278
                    cost_volumn = self.calculate_cost_volumn(stereo_metas)
                    # (B*N_views, D, fH_stereo, fW_stereo)
            cost_volumn = self.cost_volumn_net(cost_volumn)    # (B*N_views, D, fH, fW)
            depth = torch.cat([depth, cost_volumn], dim=1)     # (B*N_views, C_mid+D, fH, fW)

        if self.with_cp:
            depth = checkpoint(self.depth_conv, depth)
        else:
            # 3 residual blocks + ASPP/DCN + Conv(C_mid --> D)
            depth = self.depth_conv(depth)
            # x: (B*N_views, C_mid, fH, fW) --> (B*N_views, D, fH, fW)
        return torch.cat([depth, context], dim=1)
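The output stacks D depth logits on top of C_context context channels; downstream LSS-style view transformers typically split it and softmax the depth part. A hedged instantiation sketch (channel sizes are assumptions, and `use_dcn=False` so it runs without the compiled deformable-conv op):

import torch

D, C_ctx = 88, 80
net = DepthNet(in_channels=512, mid_channels=256,
               context_channels=C_ctx, depth_channels=D,
               use_dcn=False, use_aspp=True)
feat = torch.randn(6, 512, 16, 44)         # (B*N_views, C, fH, fW)
mlp_input = torch.randn(1, 6, 27)          # (B, N_views, 27) camera parameters

out = net(feat, mlp_input)                 # (B*N_views, D + C_ctx, fH, fW)
depth_prob = out[:, :D].softmax(dim=1)     # per-pixel depth distribution
context = out[:, D:]
print(depth_prob.shape, context.shape)     # torch.Size([6, 88, 16, 44]) torch.Size([6, 80, 16, 44])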
class DepthAggregation(nn.Module):
    """pixel cloud feature extraction."""

    def __init__(self, in_channels, mid_channels, out_channels):
        super(DepthAggregation, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.out_conv = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=True),
            # nn.BatchNorm3d(out_channels),
            # nn.ReLU(inplace=True),
        )

    @autocast(False)
    def forward(self, x):
        x = checkpoint(self.reduce_conv, x)
        short_cut = x
        x = checkpoint(self.conv, x)
        x = short_cut + x
        x = self.out_conv(x)
        return x
projects/mmdet3d_plugin/models/necks/__init__.py
from .fpn import CustomFPN
from .view_transformer import LSSViewTransformer, LSSViewTransformerBEVDepth, \
    LSSViewTransformerBEVStereo
from .lss_fpn import FPN_LSS

__all__ = ['CustomFPN', 'FPN_LSS', 'LSSViewTransformer',
           'LSSViewTransformerBEVDepth', 'LSSViewTransformerBEVStereo']