Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
lishj6
Sparsedrive
Commits
afe88104
Commit
afe88104
authored
Sep 05, 2025
by
lishj6
🏸
Browse files
init0905
parent
a48c4071
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3623 additions
and
0 deletions
+3623
-0
projects/mmdet3d_plugin/models/base_target.py
projects/mmdet3d_plugin/models/base_target.py
+49
-0
projects/mmdet3d_plugin/models/blocks.py
projects/mmdet3d_plugin/models/blocks.py
+393
-0
projects/mmdet3d_plugin/models/detection3d/__init__.py
projects/mmdet3d_plugin/models/detection3d/__init__.py
+9
-0
projects/mmdet3d_plugin/models/detection3d/decoder.py
projects/mmdet3d_plugin/models/detection3d/decoder.py
+107
-0
projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
...s/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
+300
-0
projects/mmdet3d_plugin/models/detection3d/detection3d_head.py
...cts/mmdet3d_plugin/models/detection3d/detection3d_head.py
+558
-0
projects/mmdet3d_plugin/models/detection3d/losses.py
projects/mmdet3d_plugin/models/detection3d/losses.py
+93
-0
projects/mmdet3d_plugin/models/detection3d/target.py
projects/mmdet3d_plugin/models/detection3d/target.py
+437
-0
projects/mmdet3d_plugin/models/grid_mask.py
projects/mmdet3d_plugin/models/grid_mask.py
+138
-0
projects/mmdet3d_plugin/models/instance_bank.py
projects/mmdet3d_plugin/models/instance_bank.py
+259
-0
projects/mmdet3d_plugin/models/map/__init__.py
projects/mmdet3d_plugin/models/map/__init__.py
+9
-0
projects/mmdet3d_plugin/models/map/decoder.py
projects/mmdet3d_plugin/models/map/decoder.py
+53
-0
projects/mmdet3d_plugin/models/map/loss.py
projects/mmdet3d_plugin/models/map/loss.py
+120
-0
projects/mmdet3d_plugin/models/map/map_blocks.py
projects/mmdet3d_plugin/models/map/map_blocks.py
+199
-0
projects/mmdet3d_plugin/models/map/match_cost.py
projects/mmdet3d_plugin/models/map/match_cost.py
+104
-0
projects/mmdet3d_plugin/models/map/target.py
projects/mmdet3d_plugin/models/map/target.py
+167
-0
projects/mmdet3d_plugin/models/motion/__init__.py
projects/mmdet3d_plugin/models/motion/__init__.py
+5
-0
projects/mmdet3d_plugin/models/motion/decoder.py
projects/mmdet3d_plugin/models/motion/decoder.py
+329
-0
projects/mmdet3d_plugin/models/motion/instance_queue.py
projects/mmdet3d_plugin/models/motion/instance_queue.py
+213
-0
projects/mmdet3d_plugin/models/motion/motion_blocks.py
projects/mmdet3d_plugin/models/motion/motion_blocks.py
+81
-0
No files found.
projects/mmdet3d_plugin/models/base_target.py
0 → 100644
View file @
afe88104
from
abc
import
ABC
,
abstractmethod
__all__
=
[
"BaseTargetWithDenoising"
]
class BaseTargetWithDenoising(ABC):
    """Abstract base class for target assigners that support denoising (DN) training.

    Subclasses implement Hungarian-matching based target sampling and may
    override the DN hooks below to generate / update / cache noisy instances
    across frames.
    """

    def __init__(self, num_dn_groups=0, num_temp_dn_groups=0):
        # num_dn_groups: number of denoising groups generated per frame.
        # num_temp_dn_groups: number of DN groups carried over temporally.
        super(BaseTargetWithDenoising, self).__init__()
        self.num_dn_groups = num_dn_groups
        self.num_temp_dn_groups = num_temp_dn_groups
        # Cached DN state from a previous frame; None until cache_dn() runs.
        self.dn_metas = None

    @abstractmethod
    def sample(self, cls_pred, box_pred, cls_target, box_target):
        """
        Perform Hungarian matching between predictions and ground truth,
        returning the matched ground truth corresponding to the predictions
        along with the corresponding regression weights.
        """

    def get_dn_anchors(self, cls_target, box_target, *args, **kwargs):
        """
        Generate noisy instances for the current frame, with a total of
        'self.num_dn_groups' groups.
        """
        # Base implementation produces no DN anchors; subclasses override.
        return None

    def update_dn(self, instance_feature, anchor, *args, **kwargs):
        """
        Insert the previously saved 'self.dn_metas' into the noisy instances
        of the current frame.
        """
        # No-op in the base class (implicitly returns None).

    def cache_dn(
        self,
        dn_instance_feature,
        dn_anchor,
        dn_cls_target,
        valid_mask,
        dn_id_target,
    ):
        """
        Randomly save information for 'self.num_temp_dn_groups' groups of
        temporal noisy instances to 'self.dn_metas'.
        """
        # A negative group count disables temporal DN caching entirely.
        if self.num_temp_dn_groups < 0:
            return
        # NOTE(review): only the first num_temp_dn_groups entries along dim 1
        # of dn_anchor are kept here; the other inputs are ignored by this
        # base implementation — subclasses presumably cache more.
        self.dn_metas = dict(
            dn_anchor=dn_anchor[:, : self.num_temp_dn_groups]
        )
projects/mmdet3d_plugin/models/blocks.py
0 → 100644
View file @
afe88104
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.cuda.amp.autocast_mode
import
autocast
from
mmcv.cnn
import
Linear
,
build_activation_layer
,
build_norm_layer
from
mmcv.runner.base_module
import
Sequential
,
BaseModule
from
mmcv.cnn.bricks.transformer
import
FFN
from
mmcv.utils
import
build_from_cfg
from
mmcv.cnn.bricks.drop
import
build_dropout
from
mmcv.cnn
import
xavier_init
,
constant_init
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
PLUGIN_LAYERS
,
FEEDFORWARD_NETWORK
,
)
# Optional compiled deformable-aggregation op; fall back to the pure-PyTorch
# sampling path (see DeformableFeatureAggregation.feature_sampling) when the
# extension is not built. Catch only ImportError: the original bare `except:`
# would also hide unrelated failures (e.g. a broken .so raising at load time
# would be silently masked as "op unavailable").
try:
    from ..ops import deformable_aggregation_function as DAF
except ImportError:
    DAF = None
__all__
=
[
"DeformableFeatureAggregation"
,
"DenseDepthNet"
,
"AsymmetricFFN"
,
]
def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
    """Build a flat module list of ``out_loops`` stacks, each consisting of
    ``in_loops`` x (Linear -> ReLU) followed by a single LayerNorm.

    The first Linear maps ``input_dims`` (defaults to ``embed_dims``) to
    ``embed_dims``; every subsequent Linear is embed_dims -> embed_dims.
    Returns a plain list so callers can splat it into a Sequential.
    """
    current_dims = embed_dims if input_dims is None else input_dims
    modules = []
    for _outer in range(out_loops):
        for _inner in range(in_loops):
            modules.extend(
                [Linear(current_dims, embed_dims), nn.ReLU(inplace=True)]
            )
            # After the first projection, all layers operate at embed_dims.
            current_dims = embed_dims
        modules.append(nn.LayerNorm(embed_dims))
    return modules
@ATTENTION.register_module()
class DeformableFeatureAggregation(BaseModule):
    """Aggregate multi-camera, multi-level image features at 3D key points.

    For each anchor, a key-points generator produces 3D sample points; these
    are projected into every camera/level, features are sampled (either via
    the compiled ``DAF`` op or plain ``grid_sample``), weighted by learned
    attention weights, and fused back into the instance feature.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_groups: int = 8,
        num_levels: int = 4,
        num_cams: int = 6,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        kps_generator: dict = None,
        temporal_fusion_module=None,
        use_temporal_anchor_embed=True,
        use_deformable_func=False,
        use_camera_embed=False,
        residual_mode="add",
    ):
        super(DeformableFeatureAggregation, self).__init__()
        if embed_dims % num_groups != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_groups, "
                f"but got {embed_dims} and {num_groups}"
            )
        # Per-group channel width (grouped attention, like multi-head).
        self.group_dims = int(embed_dims / num_groups)
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_groups = num_groups
        self.num_cams = num_cams
        self.use_temporal_anchor_embed = use_temporal_anchor_embed
        if use_deformable_func:
            # The compiled CUDA op must have been imported successfully.
            assert DAF is not None, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        self.attn_drop = attn_drop
        self.residual_mode = residual_mode
        self.proj_drop = nn.Dropout(proj_drop)
        # kps_generator cfg gets the shared embed_dims injected before build.
        kps_generator["embed_dims"] = embed_dims
        self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS)
        self.num_pts = self.kps_generator.num_pts
        if temporal_fusion_module is not None:
            if "embed_dims" not in temporal_fusion_module:
                temporal_fusion_module["embed_dims"] = embed_dims
            self.temp_module = build_from_cfg(
                temporal_fusion_module, PLUGIN_LAYERS
            )
        else:
            self.temp_module = None
        self.output_proj = Linear(embed_dims, embed_dims)
        if use_camera_embed:
            # 12 = flattened 3x4 sub-matrix of the per-camera projection_mat
            # (see _get_weights); weights are then shared across cameras.
            self.camera_encoder = Sequential(
                *linear_relu_ln(embed_dims, 1, 2, 12)
            )
            self.weights_fc = Linear(
                embed_dims, num_groups * num_levels * self.num_pts
            )
        else:
            self.camera_encoder = None
            self.weights_fc = Linear(
                embed_dims, num_groups * num_cams * num_levels * self.num_pts
            )

    def init_weight(self):
        # Attention weights start uniform (zero logits); output proj Xavier.
        constant_init(self.weights_fc, val=0.0, bias=0.0)
        xavier_init(self.output_proj, distribution="uniform", bias=0.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        feature_maps: List[torch.Tensor],
        metas: dict,
        **kwargs: dict,
    ):
        """Sample and fuse image features for each anchor; returns the
        instance feature with a residual ("add") or concatenated ("cat")
        update of shape (bs, num_anchor, embed_dims[*2])."""
        bs, num_anchor = instance_feature.shape[:2]
        key_points = self.kps_generator(anchor, instance_feature)
        weights = self._get_weights(instance_feature, anchor_embed, metas)
        if self.use_deformable_func:
            # Rearrange to the layout the fused CUDA op expects:
            # (bs, num_anchor, num_pts, num_cams, 2) normalized coords.
            points_2d = (
                self.project_points(
                    key_points,
                    metas["projection_mat"],
                    metas.get("image_wh"),
                )
                .permute(0, 2, 3, 1, 4)
                .reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
            )
            weights = (
                weights.permute(0, 1, 4, 2, 3, 5)
                .contiguous()
                .reshape(
                    bs,
                    num_anchor,
                    self.num_pts,
                    self.num_cams,
                    self.num_levels,
                    self.num_groups,
                )
            )
            features = DAF(*feature_maps, points_2d, weights).reshape(
                bs, num_anchor, self.embed_dims
            )
        else:
            # Pure-PyTorch fallback: grid_sample per level, then weighted
            # fusion over cameras/levels, then sum over key points.
            features = self.feature_sampling(
                feature_maps,
                key_points,
                metas["projection_mat"],
                metas.get("image_wh"),
            )
            features = self.multi_view_level_fusion(features, weights)
            features = features.sum(dim=2)  # fuse multi-point features
        output = self.proj_drop(self.output_proj(features))
        if self.residual_mode == "add":
            output = output + instance_feature
        elif self.residual_mode == "cat":
            output = torch.cat([output, instance_feature], dim=-1)
        return output

    def _get_weights(self, instance_feature, anchor_embed, metas=None):
        """Predict per-(cam, level, point, group) attention weights,
        softmax-normalized jointly over cams x levels x points."""
        bs, num_anchor = instance_feature.shape[:2]
        feature = instance_feature + anchor_embed
        if self.camera_encoder is not None:
            # Encode each camera's 3x4 projection (flattened to 12 values)
            # and broadcast-add so weights become camera-aware.
            camera_embed = self.camera_encoder(
                metas["projection_mat"][:, :, :3].reshape(
                    bs, self.num_cams, -1
                )
            )
            feature = feature[:, :, None] + camera_embed[:, None]
        weights = (
            self.weights_fc(feature)
            .reshape(bs, num_anchor, -1, self.num_groups)
            .softmax(dim=-2)
            .reshape(
                bs,
                num_anchor,
                self.num_cams,
                self.num_levels,
                self.num_pts,
                self.num_groups,
            )
        )
        if self.training and self.attn_drop > 0:
            # Attention dropout: zero whole (cam, point) slots and rescale,
            # analogous to inverted dropout.
            mask = torch.rand(
                bs, num_anchor, self.num_cams, 1, self.num_pts, 1
            )
            mask = mask.to(device=weights.device, dtype=weights.dtype)
            weights = ((mask > self.attn_drop) * weights) / (
                1 - self.attn_drop
            )
        return weights

    @staticmethod
    def project_points(key_points, projection_mat, image_wh=None):
        """Project 3D key points into each camera.

        Returns 2D points; when ``image_wh`` is given they are normalized to
        [0, 1] by the image size. Depth is clamped to 1e-5 to avoid division
        by zero for points at/behind the camera plane.
        """
        bs, num_anchor, num_pts = key_points.shape[:3]
        # Homogeneous coordinates: append 1 to (x, y, z).
        pts_extend = torch.cat(
            [key_points, torch.ones_like(key_points[..., :1])], dim=-1
        )
        points_2d = torch.matmul(
            projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
        ).squeeze(-1)
        points_2d = points_2d[..., :2] / torch.clamp(
            points_2d[..., 2:3], min=1e-5
        )
        if image_wh is not None:
            points_2d = points_2d / image_wh[:, :, None, None]
        return points_2d

    @staticmethod
    def feature_sampling(
        feature_maps: List[torch.Tensor],
        key_points: torch.Tensor,
        projection_mat: torch.Tensor,
        image_wh: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Bilinearly sample every feature level at the projected key points
        (grid_sample path used when the fused op is unavailable)."""
        num_levels = len(feature_maps)
        num_cams = feature_maps[0].shape[1]
        bs, num_anchor, num_pts = key_points.shape[:3]

        points_2d = DeformableFeatureAggregation.project_points(
            key_points, projection_mat, image_wh
        )
        # grid_sample expects coordinates in [-1, 1].
        points_2d = points_2d * 2 - 1
        points_2d = points_2d.flatten(end_dim=1)

        features = []
        for fm in feature_maps:
            features.append(
                torch.nn.functional.grid_sample(
                    fm.flatten(end_dim=1), points_2d
                )
            )
        features = torch.stack(features, dim=1)
        features = features.reshape(
            bs, num_cams, num_levels, -1, num_anchor, num_pts
        ).permute(
            0, 4, 1, 2, 5, 3
        )  # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims

        return features

    def multi_view_level_fusion(
        self,
        features: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Weighted sum of sampled features over cameras and levels; the
        channel dim is split into (num_groups, group_dims) so each group gets
        its own scalar weight."""
        bs, num_anchor = weights.shape[:2]
        features = weights[..., None] * features.reshape(
            features.shape[:-1] + (self.num_groups, self.group_dims)
        )
        # Sum over cameras (dim 2), then over levels (the new dim 2).
        features = features.sum(dim=2).sum(dim=2)
        features = features.reshape(
            bs, num_anchor, self.num_pts, self.embed_dims
        )
        return features
@PLUGIN_LAYERS.register_module()
class DenseDepthNet(BaseModule):
    """Auxiliary dense depth head: predicts a 1-channel depth map per feature
    level (via 1x1 conv + exp) and supervises it with an L1 loss against
    sparse ground-truth depths during training."""

    def __init__(
        self,
        embed_dims=256,
        num_depth_layers=1,
        equal_focal=100,
        max_depth=60,
        loss_weight=1.0,
    ):
        # equal_focal: reference focal length; predictions are rescaled by
        # (actual focal / equal_focal) so depth is focal-invariant.
        # max_depth: predictions are clipped to [0, max_depth] in the loss.
        super().__init__()
        self.embed_dims = embed_dims
        self.equal_focal = equal_focal
        self.num_depth_layers = num_depth_layers
        self.max_depth = max_depth
        self.loss_weight = loss_weight

        # One 1x1 conv (embed_dims -> 1) per supervised feature level.
        self.depth_layers = nn.ModuleList()
        for i in range(num_depth_layers):
            self.depth_layers.append(
                nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0)
            )

    def forward(self, feature_maps, focal=None, gt_depths=None):
        """Predict depths for the first num_depth_layers levels.

        Returns the training loss when ``gt_depths`` is given in training
        mode, otherwise the list of per-level depth maps.
        """
        if focal is None:
            focal = self.equal_focal
        else:
            focal = focal.reshape(-1)
        depths = []
        for i, feat in enumerate(feature_maps[: self.num_depth_layers]):
            # exp() keeps predicted depth strictly positive.
            depth = self.depth_layers[i](
                feat.flatten(end_dim=1).float()
            ).exp()
            # transpose trick broadcasts the per-view focal over the batch
            # dim, then restores the original layout.
            depth = depth.transpose(0, -1) * focal / self.equal_focal
            depth = depth.transpose(0, -1)
            depths.append(depth)
        if gt_depths is not None and self.training:
            loss = self.loss(depths, gt_depths)
            return loss
        return depths

    def loss(self, depth_preds, gt_depths):
        """Mean absolute error over valid (gt > 0, non-NaN pred) pixels,
        averaged across levels and scaled by loss_weight."""
        loss = 0.0
        for pred, gt in zip(depth_preds, gt_depths):
            pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1)
            gt = gt.reshape(-1)
            # Foreground = pixels with a positive gt depth and a finite pred.
            fg_mask = torch.logical_and(
                gt > 0.0, torch.logical_not(torch.isnan(pred))
            )
            gt = gt[fg_mask]
            pred = pred[fg_mask]
            pred = torch.clip(pred, 0.0, self.max_depth)
            # Force fp32 here: the L1 sum can overflow/underflow under AMP.
            with autocast(enabled=False):
                error = torch.abs(pred - gt).sum()
                _loss = (
                    error
                    / max(1.0, len(gt) * len(depth_preds))
                    * self.loss_weight
                )
            loss = loss + _loss
        return loss
@FEEDFORWARD_NETWORK.register_module()
class AsymmetricFFN(BaseModule):
    """FFN whose input width (``in_channels``) may differ from its output
    width (``embed_dims``); an optional pre-norm and an identity/linear
    shortcut adapt the residual path accordingly."""

    def __init__(
        self,
        in_channels=None,
        pre_norm=None,
        embed_dims=256,
        feedforward_channels=1024,
        num_fcs=2,
        act_cfg=dict(type="ReLU", inplace=True),
        ffn_drop=0.0,
        dropout_layer=None,
        add_identity=True,
        init_cfg=None,
        **kwargs,
    ):
        super(AsymmetricFFN, self).__init__(init_cfg)
        assert num_fcs >= 2, (
            "num_fcs should be no less " f"than 2. got {num_fcs}."
        )
        self.in_channels = in_channels
        self.pre_norm = pre_norm
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        if in_channels is None:
            in_channels = embed_dims
        # Width of the tensor actually fed to the FFN — this is also the
        # width of the residual identity in forward(). Captured here because
        # the fc loop below rebinds `in_channels` to feedforward_channels.
        ffn_in_channels = in_channels
        if pre_norm is not None:
            self.pre_norm = build_norm_layer(pre_norm, in_channels)[1]

        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels),
                    self.activate,
                    nn.Dropout(ffn_drop),
                )
            )
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = (
            build_dropout(dropout_layer)
            if dropout_layer
            else torch.nn.Identity()
        )
        self.add_identity = add_identity
        if self.add_identity:
            # BUGFIX: the original compared the loop-mutated `in_channels`
            # (always feedforward_channels here) against embed_dims, and
            # built Linear(self.in_channels, ...) where self.in_channels may
            # still be None — a TypeError when in_channels was defaulted.
            # The shortcut must map the FFN *input* width to embed_dims,
            # using Identity when they already match.
            self.identity_fc = (
                torch.nn.Identity()
                if ffn_in_channels == embed_dims
                else Linear(ffn_in_channels, embed_dims)
            )

    def forward(self, x, identity=None):
        """Apply (optional pre-norm ->) FFN; add the (projected) identity
        shortcut when add_identity is set."""
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        # Default identity is the post-pre-norm input.
        if identity is None:
            identity = x
        identity = self.identity_fc(identity)
        return identity + self.dropout_layer(out)
projects/mmdet3d_plugin/models/detection3d/__init__.py
0 → 100644
View file @
afe88104
from
.decoder
import
SparseBox3DDecoder
from
.target
import
SparseBox3DTarget
from
.detection3d_blocks
import
(
SparseBox3DRefinementModule
,
SparseBox3DKeyPointsGenerator
,
SparseBox3DEncoder
,
)
from
.losses
import
SparseBox3DLoss
from
.detection3d_head
import
Sparse4DHead
projects/mmdet3d_plugin/models/detection3d/decoder.py
0 → 100644
View file @
afe88104
from
typing
import
Optional
import
torch
from
mmdet.core.bbox.builder
import
BBOX_CODERS
from
projects.mmdet3d_plugin.core.box3d
import
*
def decode_box(box):
    """Convert the internal box encoding to the output format.

    Internally boxes store log-sizes and (sin, cos) yaw; the decoded box is
    [x, y, z, w, l, h, yaw, vx, ...] with sizes exponentiated and yaw
    recovered via atan2.
    """
    heading = torch.atan2(box[..., SIN_YAW], box[..., COS_YAW])
    center = box[..., [X, Y, Z]]
    dims = box[..., [W, L, H]].exp()
    trailing = box[..., VX:]
    return torch.cat(
        [center, dims, heading[..., None], trailing],
        dim=-1,
    )
@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
    """Decode classification scores and box predictions into per-sample
    detection dicts (top-k selection, optional score threshold, optional
    centerness-based rescoring, optional instance ids)."""

    def __init__(
        self,
        num_output: int = 300,
        score_threshold: Optional[float] = None,
        sorted: bool = True,
    ):
        # num_output: top-k detections kept per sample.
        # score_threshold: optional minimum sigmoid score.
        # sorted: whether topk returns scores in descending order.
        super(SparseBox3DDecoder, self).__init__()
        self.num_output = num_output
        self.score_threshold = score_threshold
        self.sorted = sorted

    def decode(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        quality=None,
        output_idx=-1,
    ):
        """Decode the ``output_idx``-th decoder layer's predictions.

        Returns a list (one dict per batch sample) with keys "boxes_3d",
        "scores_3d", "labels_3d" and, when available, "cls_scores" /
        "instance_ids".
        """
        # With instance ids (tracking), classes are squeezed to the argmax
        # so each anchor yields exactly one candidate.
        squeeze_cls = instance_id is not None

        cls_scores = cls_scores[output_idx].sigmoid()

        if squeeze_cls:
            cls_scores, cls_ids = cls_scores.max(dim=-1)
            cls_scores = cls_scores.unsqueeze(dim=-1)

        box_preds = box_preds[output_idx]
        bs, num_pred, num_cls = cls_scores.shape
        # Joint top-k over (anchor, class); index // num_cls recovers the
        # anchor, index % num_cls the class.
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )
        if not squeeze_cls:
            cls_ids = indices % num_cls
        if self.score_threshold is not None:
            mask = cls_scores >= self.score_threshold

        # BUGFIX: guard quality itself before subscripting — the original
        # evaluated quality[output_idx] first and raised TypeError whenever
        # decode() was called with the default quality=None.
        if quality is None or quality[output_idx] is None:
            quality = None
        if quality is not None:
            # Rescore by predicted centerness; keep the original scores too.
            centerness = quality[output_idx][..., CNS]
            centerness = torch.gather(centerness, 1, indices // num_cls)
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()
            cls_scores, idx = torch.sort(
                cls_scores, dim=1, descending=True
            )
            if not squeeze_cls:
                cls_ids = torch.gather(cls_ids, 1, idx)
            if self.score_threshold is not None:
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)

        output = []
        for i in range(bs):
            category_ids = cls_ids[i]
            if squeeze_cls:
                category_ids = category_ids[indices[i]]
            scores = cls_scores[i]
            box = box_preds[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                category_ids = category_ids[mask[i]]
                scores = scores[mask[i]]
                box = box[mask[i]]
            if quality is not None:
                scores_origin = cls_scores_origin[i]
                if self.score_threshold is not None:
                    scores_origin = scores_origin[mask[i]]
            box = decode_box(box)
            output.append(
                {
                    "boxes_3d": box.cpu(),
                    "scores_3d": scores.cpu(),
                    "labels_3d": category_ids.cpu(),
                }
            )
            if quality is not None:
                output[-1]["cls_scores"] = scores_origin.cpu()
            if instance_id is not None:
                # squeeze_cls holds here, so num_cls == 1 and indices index
                # anchors directly.
                ids = instance_id[i, indices[i]]
                if self.score_threshold is not None:
                    ids = ids[mask[i]]
                output[-1]["instance_ids"] = ids
        return output
projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
0 → 100644
View file @
afe88104
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
mmcv.cnn
import
Linear
,
Scale
,
bias_init_with_prob
from
mmcv.runner.base_module
import
Sequential
,
BaseModule
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.registry
import
(
PLUGIN_LAYERS
,
POSITIONAL_ENCODING
,
)
from
projects.mmdet3d_plugin.core.box3d
import
*
from
..blocks
import
linear_relu_ln
__all__
=
[
"SparseBox3DRefinementModule"
,
"SparseBox3DKeyPointsGenerator"
,
"SparseBox3DEncoder"
,
]
@POSITIONAL_ENCODING.register_module()
class SparseBox3DEncoder(BaseModule):
    """Embed a 3D box state vector into a feature: position, size, yaw and
    (optionally) velocity get separate MLP branches whose outputs are summed
    ("add") or concatenated ("cat")."""

    def __init__(
        self,
        embed_dims,
        vel_dims=3,
        mode="add",
        output_fc=True,
        in_loops=1,
        out_loops=2,
    ):
        # embed_dims: int shared by all branches, or a list/tuple of 5 dims
        #   [pos, size, yaw, vel, output].
        # vel_dims: number of velocity components to embed (0 disables).
        super().__init__()
        assert mode in ["add", "cat"]
        self.embed_dims = embed_dims
        self.vel_dims = vel_dims
        self.mode = mode

        # Small factory closing over in_loops/out_loops for branch MLPs.
        def embedding_layer(input_dims, output_dims):
            return nn.Sequential(
                *linear_relu_ln(output_dims, in_loops, out_loops, input_dims)
            )

        if not isinstance(embed_dims, (list, tuple)):
            embed_dims = [embed_dims] * 5
        self.pos_fc = embedding_layer(3, embed_dims[0])
        self.size_fc = embedding_layer(3, embed_dims[1])
        self.yaw_fc = embedding_layer(2, embed_dims[2])
        if vel_dims > 0:
            self.vel_fc = embedding_layer(self.vel_dims, embed_dims[3])
        if output_fc:
            self.output_fc = embedding_layer(embed_dims[-1], embed_dims[-1])
        else:
            self.output_fc = None

    def forward(self, box_3d: torch.Tensor):
        """Encode boxes (..., state_dim) -> (..., embed) using the branch
        layout [X,Y,Z | W,L,H | SIN_YAW,COS_YAW | VX...]."""
        pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]])
        size_feat = self.size_fc(box_3d[..., [W, L, H]])
        yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]])
        if self.mode == "add":
            output = pos_feat + size_feat + yaw_feat
        elif self.mode == "cat":
            output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1)
        if self.vel_dims > 0:
            vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims])
            if self.mode == "add":
                output = output + vel_feat
            elif self.mode == "cat":
                output = torch.cat([output, vel_feat], dim=-1)
        if self.output_fc is not None:
            output = self.output_fc(output)
        return output
@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
    """Refine anchor box states from instance features and, optionally,
    predict classification logits and a 2-dim quality estimate."""

    def __init__(
        self,
        embed_dims=256,
        output_dim=11,
        num_cls=10,
        normalize_yaw=False,
        refine_yaw=False,
        with_cls_branch=True,
        with_quality_estimation=False,
    ):
        # output_dim: full box state size; > 8 means velocity is included.
        super(SparseBox3DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.output_dim = output_dim
        self.num_cls = num_cls
        self.normalize_yaw = normalize_yaw
        self.refine_yaw = refine_yaw

        # Indices refined as residuals on top of the anchor state.
        self.refine_state = [X, Y, Z, W, L, H]
        if self.refine_yaw:
            self.refine_state += [SIN_YAW, COS_YAW]

        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )
        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )
        self.with_quality_estimation = with_quality_estimation
        if with_quality_estimation:
            self.quality_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, 2),
            )

    def init_weight(self):
        # Focal-loss style bias init so initial cls probabilities ~0.01.
        if self.with_cls_branch:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.cls_layers[-1].bias, bias_init)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Return (refined_box, cls_logits_or_None, quality_or_None)."""
        feature = instance_feature + anchor_embed
        output = self.layers(feature)
        # Residual refinement of position/size (and yaw if enabled).
        output[..., self.refine_state] = (
            output[..., self.refine_state] + anchor[..., self.refine_state]
        )
        if self.normalize_yaw:
            # Keep (sin, cos) on the unit circle.
            output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(
                output[..., [SIN_YAW, COS_YAW]], dim=-1
            )
        if self.output_dim > 8:
            if not isinstance(time_interval, torch.Tensor):
                time_interval = instance_feature.new_tensor(time_interval)
            # The VX: slice is predicted as a translation; divide by the
            # (per-sample) time interval to get velocity. The transpose
            # trick broadcasts time_interval over the batch dimension.
            translation = torch.transpose(output[..., VX:], 0, -1)
            velocity = torch.transpose(translation / time_interval, 0, -1)
            output[..., VX:] = velocity + anchor[..., VX:]

        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            # Classification reads the raw instance feature (no anchor embed).
            cls = self.cls_layers(instance_feature)
        else:
            cls = None
        if return_cls and self.with_quality_estimation:
            quality = self.quality_layers(feature)
        else:
            quality = None
        return output, cls, quality
@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
    """Generate 3D key points inside/around each anchor box: a fixed set of
    scale offsets plus optional learnable offsets, rotated by the box yaw and
    translated to the box center. Can also warp key points and anchors into
    past (temporal) frames."""

    def __init__(
        self,
        embed_dims=256,
        num_learnable_pts=0,
        fix_scale=None,
    ):
        # fix_scale: tuple of (sx, sy, sz) offsets in box-size units;
        # default is the single box center (0, 0, 0).
        super(SparseBox3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_learnable_pts = num_learnable_pts
        if fix_scale is None:
            fix_scale = ((0.0, 0.0, 0.0),)
        # Registered as a frozen Parameter so it moves with the module.
        self.fix_scale = nn.Parameter(
            torch.tensor(fix_scale), requires_grad=False
        )
        self.num_pts = len(self.fix_scale) + num_learnable_pts
        if num_learnable_pts > 0:
            self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)

    def init_weight(self):
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Return key points (bs, num_anchor, num_pts, 3); when all temporal
        arguments are provided, also return the list of key points warped
        into each past frame."""
        bs, num_anchor = anchor.shape[:2]
        # Box sizes are stored as logs — exp() recovers metric sizes.
        size = anchor[..., None, [W, L, H]].exp()
        key_points = self.fix_scale * size
        if self.num_learnable_pts > 0 and instance_feature is not None:
            # Learnable offsets in (-0.5, 0.5) box-size units.
            learnable_scale = (
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2
            )

        # Yaw rotation about the z-axis built from the stored (sin, cos).
        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])
        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1
        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)
        key_points = key_points + anchor[..., None, [X, Y, Z]]

        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        # Temporal path: move points backwards along the anchor velocity,
        # then transform into each past frame's coordinate system.
        temp_key_points_list = []
        velocity = anchor[..., VX:]
        for i, t_time in enumerate(temp_timestamps):
            time_interval = cur_timestamp - t_time
            translation = (
                velocity
                * time_interval.to(dtype=velocity.dtype)[:, None, None]
            )
            temp_key_points = key_points - translation[:, :, None]
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            # Homogeneous transform with the 3x4 part of T_cur2temp.
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    @staticmethod
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project full anchor states (center, size, yaw, velocity) into each
        destination frame given src->dst transforms; velocity compensation is
        applied when a time interval is available."""
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )

            center = anchor[..., [X, Y, Z]]
            # Prefer explicit intervals; else derive from timestamps.
            if time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                # Ego-motion compensation: undo movement over the interval.
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation
            # Rigid transform: rotate then translate the center.
            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            size = anchor[..., [W, L, H]]
            # Rotate the heading vector (cos, sin) with the 2x2 block, then
            # swap back to the stored (sin, cos) order.
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            yaw = yaw[..., [1, 0]]
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            dst_anchors.append(dst_anchor)
        return dst_anchors

    @staticmethod
    def distance(anchor):
        """Euclidean distance of each anchor center from the ego origin in
        the ground (x, y) plane."""
        return torch.norm(anchor[..., :2], p=2, dim=-1)
projects/mmdet3d_plugin/models/detection3d/detection3d_head.py
0 → 100644
View file @
afe88104
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
warnings
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks.registry
import
(
ATTENTION
,
PLUGIN_LAYERS
,
POSITIONAL_ENCODING
,
FEEDFORWARD_NETWORK
,
NORM_LAYERS
,
)
from
mmcv.runner
import
BaseModule
,
force_fp32
from
mmcv.utils
import
build_from_cfg
from
mmdet.core.bbox.builder
import
BBOX_SAMPLERS
from
mmdet.core.bbox.builder
import
BBOX_CODERS
from
mmdet.models
import
HEADS
,
LOSSES
from
mmdet.core
import
reduce_mean
from
..blocks
import
DeformableFeatureAggregation
as
DFG
__all__
=
[
"Sparse4DHead"
]
@
HEADS
.
register_module
()
class
Sparse4DHead
(
BaseModule
):
    def __init__(
        self,
        instance_bank: dict,
        anchor_encoder: dict,
        graph_model: dict,
        norm_layer: dict,
        ffn: dict,
        deformable_model: dict,
        refine_layer: dict,
        num_decoder: int = 6,
        num_single_frame_decoder: int = -1,
        temp_graph_model: dict = None,
        loss_cls: dict = None,
        loss_reg: dict = None,
        decoder: dict = None,
        sampler: dict = None,
        gt_cls_key: str = "gt_labels_3d",
        gt_reg_key: str = "gt_bboxes_3d",
        gt_id_key: str = "instance_id",
        with_instance_id: bool = True,
        task_prefix: str = 'det',
        reg_weights: List = None,
        operation_order: Optional[List[str]] = None,
        cls_threshold_to_reg: float = -1,
        dn_loss_weight: float = 5.0,
        decouple_attn: bool = True,
        init_cfg: dict = None,
        **kwargs,
    ):
        """Build the Sparse4D detection head from per-component configs.

        Every ``*: dict`` argument is an mmcv config that is instantiated via
        ``build_from_cfg`` against the matching registry (see
        ``self.op_config_map`` below).  ``None`` configs yield ``None`` layers,
        which ``forward``/``init_weights`` skip.
        """
        super(Sparse4DHead, self).__init__(init_cfg)
        self.num_decoder = num_decoder
        self.num_single_frame_decoder = num_single_frame_decoder
        self.gt_cls_key = gt_cls_key
        self.gt_reg_key = gt_reg_key
        self.gt_id_key = gt_id_key
        self.with_instance_id = with_instance_id
        self.task_prefix = task_prefix
        self.cls_threshold_to_reg = cls_threshold_to_reg
        # NOTE(review): dn_loss_weight is stored but not referenced elsewhere
        # in this file view — confirm it is consumed by a loss config.
        self.dn_loss_weight = dn_loss_weight
        self.decouple_attn = decouple_attn

        if reg_weights is None:
            # default: uniform weight over all 10 regression dims
            self.reg_weights = [1.0] * 10
        else:
            self.reg_weights = reg_weights

        if operation_order is None:
            operation_order = [
                "temp_gnn",
                "gnn",
                "norm",
                "deformable",
                "norm",
                "ffn",
                "norm",
                "refine",
            ] * num_decoder
            # drop the leading 'temp_gnn', 'gnn' and 'norm' ops so the first
            # transformer block starts directly at 'deformable'
            operation_order = operation_order[3:]
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.instance_bank = build(instance_bank, PLUGIN_LAYERS)
        self.anchor_encoder = build(anchor_encoder, POSITIONAL_ENCODING)
        self.sampler = build(sampler, BBOX_SAMPLERS)
        self.decoder = build(decoder, BBOX_CODERS)
        self.loss_cls = build(loss_cls, LOSSES)
        self.loss_reg = build(loss_reg, LOSSES)
        # maps each op name to (config, registry) used to build its layer
        self.op_config_map = {
            "temp_gnn": [temp_graph_model, ATTENTION],
            "gnn": [graph_model, ATTENTION],
            "norm": [norm_layer, NORM_LAYERS],
            "ffn": [ffn, FEEDFORWARD_NETWORK],
            "deformable": [deformable_model, ATTENTION],
            "refine": [refine_layer, PLUGIN_LAYERS],
        }
        # one layer per entry of operation_order (None for unknown/None ops)
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )
        self.embed_dims = self.instance_bank.embed_dims
        if self.decouple_attn:
            # widen the value path to 2x so query/key can carry concatenated
            # positional embeddings (see graph_model)
            self.fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.fc_before = nn.Identity()
            self.fc_after = nn.Identity()
def
init_weights
(
self
):
for
i
,
op
in
enumerate
(
self
.
operation_order
):
if
self
.
layers
[
i
]
is
None
:
continue
elif
op
!=
"refine"
:
for
p
in
self
.
layers
[
i
].
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
hasattr
(
m
,
"init_weight"
):
m
.
init_weight
()
def
graph_model
(
self
,
index
,
query
,
key
=
None
,
value
=
None
,
query_pos
=
None
,
key_pos
=
None
,
**
kwargs
,
):
if
self
.
decouple_attn
:
query
=
torch
.
cat
([
query
,
query_pos
],
dim
=-
1
)
if
key
is
not
None
:
key
=
torch
.
cat
([
key
,
key_pos
],
dim
=-
1
)
query_pos
,
key_pos
=
None
,
None
if
value
is
not
None
:
value
=
self
.
fc_before
(
value
)
return
self
.
fc_after
(
self
.
layers
[
index
](
query
,
key
,
value
,
query_pos
=
query_pos
,
key_pos
=
key_pos
,
**
kwargs
,
)
)
    def forward(
        self,
        feature_maps: Union[torch.Tensor, List],
        metas: dict,
    ):
        """Run the full decoder stack over image features.

        Returns a dict with per-decoder ``classification``/``prediction``/
        ``quality`` lists, the final ``instance_feature``/``anchor_embed``,
        and (in training with denoising) the ``dn_*`` targets/predictions.
        """
        if isinstance(feature_maps, torch.Tensor):
            feature_maps = [feature_maps]
        batch_size = feature_maps[0].shape[0]

        # ========= get instance info ============
        # drop cached dn metas when the batch size changed between steps
        if (
            self.sampler.dn_metas is not None
            and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
        ):
            self.sampler.dn_metas = None
        (
            instance_feature,
            anchor,
            temp_instance_feature,
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

        # ========= prepare for denoising training ============
        # 1. get dn metas: noisy-anchors and corresponding GT
        # 2. concat learnable instances and noisy instances
        # 3. get attention mask
        attn_mask = None
        dn_metas = None
        temp_dn_reg_target = None
        if self.training and hasattr(self.sampler, "get_dn_anchors"):
            if self.gt_id_key in metas["img_metas"][0]:
                # NOTE(review): hard-coded .cuda() — assumes CUDA training
                gt_instance_id = [
                    torch.from_numpy(x[self.gt_id_key]).cuda()
                    for x in metas["img_metas"]
                ]
            else:
                gt_instance_id = None
            dn_metas = self.sampler.get_dn_anchors(
                metas[self.gt_cls_key],
                metas[self.gt_reg_key],
                gt_instance_id,
            )
        if dn_metas is not None:
            (
                dn_anchor,
                dn_reg_target,
                dn_cls_target,
                dn_attn_mask,
                valid_mask,
                dn_id_target,
            ) = dn_metas
            num_dn_anchor = dn_anchor.shape[1]
            # pad dn anchors with zeros if their state dim is narrower than
            # the learnable anchors (e.g. missing velocity dims)
            if dn_anchor.shape[-1] != anchor.shape[-1]:
                remain_state_dims = anchor.shape[-1] - dn_anchor.shape[-1]
                dn_anchor = torch.cat(
                    [
                        dn_anchor,
                        dn_anchor.new_zeros(
                            batch_size, num_dn_anchor, remain_state_dims
                        ),
                    ],
                    dim=-1,
                )
            anchor = torch.cat([anchor, dn_anchor], dim=1)
            # dn instances start from zero features
            instance_feature = torch.cat(
                [
                    instance_feature,
                    instance_feature.new_zeros(
                        batch_size, num_dn_anchor, instance_feature.shape[-1]
                    ),
                ],
                dim=1,
            )
            num_instance = instance_feature.shape[1]
            num_free_instance = num_instance - num_dn_anchor
            # True = blocked; free instances attend to each other, dn groups
            # only per dn_attn_mask (no leakage from GT-derived queries)
            attn_mask = anchor.new_ones(
                (num_instance, num_instance), dtype=torch.bool
            )
            attn_mask[:num_free_instance, :num_free_instance] = False
            attn_mask[num_free_instance:, num_free_instance:] = dn_attn_mask

        anchor_embed = self.anchor_encoder(anchor)
        if temp_anchor is not None:
            temp_anchor_embed = self.anchor_encoder(temp_anchor)
        else:
            temp_anchor_embed = None

        # =================== forward the layers ====================
        prediction = []
        classification = []
        quality = []
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op == "temp_gnn":
                # cross-attend current instances to cached temporal instances
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )
            elif op == "gnn":
                # self-attention among all (free + dn) instances
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    value=instance_feature,
                    query_pos=anchor_embed,
                    attn_mask=attn_mask,
                )
            elif op == "norm" or op == "ffn":
                instance_feature = self.layers[i](instance_feature)
            elif op == "deformable":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    feature_maps,
                    metas,
                )
            elif op == "refine":
                anchor, cls, qt = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    time_interval=time_interval,
                    return_cls=True,
                )
                prediction.append(anchor)
                classification.append(cls)
                quality.append(qt)
                # after the last single-frame decoder, merge current
                # instances with the temporal instance bank
                if len(prediction) == self.num_single_frame_decoder:
                    instance_feature, anchor = self.instance_bank.update(
                        instance_feature, anchor, cls
                    )
                    if (
                        dn_metas is not None
                        and self.sampler.num_temp_dn_groups > 0
                        and dn_id_target is not None
                    ):
                        (
                            instance_feature,
                            anchor,
                            temp_dn_reg_target,
                            temp_dn_cls_target,
                            temp_valid_mask,
                            dn_id_target,
                        ) = self.sampler.update_dn(
                            instance_feature,
                            anchor,
                            dn_reg_target,
                            dn_cls_target,
                            valid_mask,
                            dn_id_target,
                            self.instance_bank.num_anchor,
                            self.instance_bank.mask,
                        )
                # anchors were refined, so re-encode their embeddings
                anchor_embed = self.anchor_encoder(anchor)
                if (
                    len(prediction) > self.num_single_frame_decoder
                    and temp_anchor_embed is not None
                ):
                    temp_anchor_embed = anchor_embed[
                        :, : self.instance_bank.num_temp_instances
                    ]
            else:
                raise NotImplementedError(f"{op} is not supported.")

        output = {}

        # split predictions of learnable instances and noisy instances
        if dn_metas is not None:
            dn_classification = [
                x[:, num_free_instance:] for x in classification
            ]
            classification = [
                x[:, :num_free_instance] for x in classification
            ]
            dn_prediction = [x[:, num_free_instance:] for x in prediction]
            prediction = [x[:, :num_free_instance] for x in prediction]
            quality = [
                x[:, :num_free_instance] if x is not None else None
                for x in quality
            ]
            output.update(
                {
                    "dn_prediction": dn_prediction,
                    "dn_classification": dn_classification,
                    "dn_reg_target": dn_reg_target,
                    "dn_cls_target": dn_cls_target,
                    "dn_valid_mask": valid_mask,
                }
            )
            if temp_dn_reg_target is not None:
                output.update(
                    {
                        "temp_dn_reg_target": temp_dn_reg_target,
                        "temp_dn_cls_target": temp_dn_cls_target,
                        "temp_dn_valid_mask": temp_valid_mask,
                        "dn_id_target": dn_id_target,
                    }
                )
                # from here on, cache the temporally-updated dn targets
                dn_cls_target = temp_dn_cls_target
                valid_mask = temp_valid_mask
            dn_instance_feature = instance_feature[:, num_free_instance:]
            dn_anchor = anchor[:, num_free_instance:]
            instance_feature = instance_feature[:, :num_free_instance]
            anchor_embed = anchor_embed[:, :num_free_instance]
            anchor = anchor[:, :num_free_instance]
            cls = cls[:, :num_free_instance]

            # cache dn_metas for temporal denoising
            self.sampler.cache_dn(
                dn_instance_feature,
                dn_anchor,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )
        output.update(
            {
                "classification": classification,
                "prediction": prediction,
                "quality": quality,
                "instance_feature": instance_feature,
                "anchor_embed": anchor_embed,
            }
        )

        # cache current instances for temporal modeling
        self.instance_bank.cache(
            instance_feature, anchor, cls, metas, feature_maps
        )
        if self.with_instance_id:
            instance_id = self.instance_bank.get_instance_id(
                cls, anchor, self.decoder.score_threshold
            )
            output["instance_id"] = instance_id
        return output
    @force_fp32(apply_to=("model_outs"))
    def loss(self, model_outs, data, feature_maps=None):
        """Compute classification/regression losses for every decoder layer.

        Matching targets come from ``self.sampler.sample`` (Hungarian
        assignment); denoising losses reuse the fixed targets produced at
        ``get_dn_anchors`` time.

        NOTE(review): ``apply_to=("model_outs")`` is a plain string, not a
        one-element tuple — it works via substring matching in mmcv, but a
        trailing comma would be clearer.
        """
        # ===================== prediction losses ======================
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        quality = model_outs["quality"]
        output = {}
        for decoder_idx, (cls, reg, qt) in enumerate(
            zip(cls_scores, reg_preds, quality)
        ):
            reg = reg[..., : len(self.reg_weights)]
            cls_target, reg_target, reg_weights = self.sampler.sample(
                cls,
                reg,
                data[self.gt_cls_key],
                data[self.gt_reg_key],
            )
            reg_target = reg_target[..., : len(self.reg_weights)]
            # NOTE(review): reg_target_full and mask_valid are computed but
            # never read in this method — possibly leftovers; confirm.
            reg_target_full = reg_target.clone()
            # positives are rows with a non-all-zero regression target
            mask = torch.logical_not(torch.all(reg_target == 0, dim=-1))
            mask_valid = mask.clone()

            # distributed-mean positive count, floored at 1
            num_pos = max(
                reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0
            )
            if self.cls_threshold_to_reg > 0:
                # only regress on positives the classifier is confident about
                threshold = self.cls_threshold_to_reg
                mask = torch.logical_and(
                    mask, cls.max(dim=-1).values.sigmoid() > threshold
                )

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_loss = self.loss_cls(cls, cls_target, avg_factor=num_pos)

            mask = mask.reshape(-1)
            reg_weights = reg_weights * reg.new_tensor(self.reg_weights)
            reg_target = reg_target.flatten(end_dim=1)[mask]
            reg = reg.flatten(end_dim=1)[mask]
            reg_weights = reg_weights.flatten(end_dim=1)[mask]
            # NaN targets (e.g. missing velocity) contribute zero
            reg_target = torch.where(
                reg_target.isnan(), reg.new_tensor(0.0), reg_target
            )
            cls_target = cls_target[mask]
            if qt is not None:
                qt = qt.flatten(end_dim=1)[mask]

            reg_loss = self.loss_reg(
                reg,
                reg_target,
                weight=reg_weights,
                avg_factor=num_pos,
                prefix=f"{self.task_prefix}_",
                suffix=f"_{decoder_idx}",
                quality=qt,
                cls_target=cls_target,
            )

            output[f"{self.task_prefix}_loss_cls_{decoder_idx}"] = cls_loss
            output.update(reg_loss)

        if "dn_prediction" not in model_outs:
            return output

        # ===================== denoising losses ======================
        dn_cls_scores = model_outs["dn_classification"]
        dn_reg_preds = model_outs["dn_prediction"]

        (
            dn_valid_mask,
            dn_cls_target,
            dn_reg_target,
            dn_pos_mask,
            reg_weights,
            num_dn_pos,
        ) = self.prepare_for_dn_loss(model_outs)
        for decoder_idx, (cls, reg) in enumerate(
            zip(dn_cls_scores, dn_reg_preds)
        ):
            # from the first temporal decoder onward, switch to the
            # temporally-updated dn targets
            if (
                "temp_dn_valid_mask" in model_outs
                and decoder_idx == self.num_single_frame_decoder
            ):
                (
                    dn_valid_mask,
                    dn_cls_target,
                    dn_reg_target,
                    dn_pos_mask,
                    reg_weights,
                    num_dn_pos,
                ) = self.prepare_for_dn_loss(model_outs, prefix="temp_")

            cls_loss = self.loss_cls(
                cls.flatten(end_dim=1)[dn_valid_mask],
                dn_cls_target,
                avg_factor=num_dn_pos,
            )
            reg_loss = self.loss_reg(
                reg.flatten(end_dim=1)[dn_valid_mask][dn_pos_mask][
                    ..., : len(self.reg_weights)
                ],
                dn_reg_target,
                avg_factor=num_dn_pos,
                weight=reg_weights,
                prefix=f"{self.task_prefix}_",
                suffix=f"_dn_{decoder_idx}",
            )
            output[
                f"{self.task_prefix}_loss_cls_dn_{decoder_idx}"
            ] = cls_loss
            output.update(reg_loss)
        return output
def
prepare_for_dn_loss
(
self
,
model_outs
,
prefix
=
""
):
dn_valid_mask
=
model_outs
[
f
"
{
prefix
}
dn_valid_mask"
].
flatten
(
end_dim
=
1
)
dn_cls_target
=
model_outs
[
f
"
{
prefix
}
dn_cls_target"
].
flatten
(
end_dim
=
1
)[
dn_valid_mask
]
dn_reg_target
=
model_outs
[
f
"
{
prefix
}
dn_reg_target"
].
flatten
(
end_dim
=
1
)[
dn_valid_mask
][...,
:
len
(
self
.
reg_weights
)]
dn_pos_mask
=
dn_cls_target
>=
0
dn_reg_target
=
dn_reg_target
[
dn_pos_mask
]
reg_weights
=
dn_reg_target
.
new_tensor
(
self
.
reg_weights
)[
None
].
tile
(
dn_reg_target
.
shape
[
0
],
1
)
num_dn_pos
=
max
(
reduce_mean
(
torch
.
sum
(
dn_valid_mask
).
to
(
dtype
=
reg_weights
.
dtype
)),
1.0
,
)
return
(
dn_valid_mask
,
dn_cls_target
,
dn_reg_target
,
dn_pos_mask
,
reg_weights
,
num_dn_pos
,
)
@
force_fp32
(
apply_to
=
(
"model_outs"
))
def
post_process
(
self
,
model_outs
,
output_idx
=-
1
):
return
self
.
decoder
.
decode
(
model_outs
[
"classification"
],
model_outs
[
"prediction"
],
model_outs
.
get
(
"instance_id"
),
model_outs
.
get
(
"quality"
),
output_idx
=
output_idx
,
)
projects/mmdet3d_plugin/models/detection3d/losses.py
0 → 100644
View file @
afe88104
import
torch
import
torch.nn
as
nn
from
mmcv.utils
import
build_from_cfg
from
mmdet.models.builder
import
LOSSES
from
projects.mmdet3d_plugin.core.box3d
import
*
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
    """Composite box loss: base box regression plus optional centerness and
    yaw-direction ("yawness") quality terms."""

    def __init__(
        self,
        loss_box,
        loss_centerness=None,
        loss_yawness=None,
        cls_allow_reverse=None,
    ):
        """Args:
            loss_box: config of the base box regression loss.
            loss_centerness: optional config of the centerness quality loss.
            loss_yawness: optional config of the yaw-direction quality loss.
            cls_allow_reverse: class ids whose yaw may be flipped 180 deg.
        """
        super().__init__()

        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.loss_box = build(loss_box, LOSSES)
        self.loss_cns = build(loss_centerness, LOSSES)
        self.loss_yns = build(loss_yawness, LOSSES)
        self.cls_allow_reverse = cls_allow_reverse

    def forward(
        self,
        box,
        box_target,
        weight=None,
        avg_factor=None,
        prefix="",
        suffix="",
        quality=None,
        cls_target=None,
        **kwargs,
    ):
        """Return a dict of named losses: ``{prefix}loss_box{suffix}`` and,
        when ``quality`` is given, ``loss_cns``/``loss_yns`` as well."""
        # Some categories do not distinguish between positive and negative
        # directions. For example, barrier in nuScenes dataset.
        if self.cls_allow_reverse is not None and cls_target is not None:
            # flipped when the predicted yaw vector points opposite the GT
            if_reverse = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                < 0
            )
            if_reverse = (
                torch.isin(
                    cls_target, cls_target.new_tensor(self.cls_allow_reverse)
                )
                & if_reverse
            )
            # NOTE(review): mutates box_target in place; callers currently
            # pass masked copies, so no external tensor is modified — verify
            # if the call sites change.
            box_target[..., [SIN_YAW, COS_YAW]] = torch.where(
                if_reverse[..., None],
                -box_target[..., [SIN_YAW, COS_YAW]],
                box_target[..., [SIN_YAW, COS_YAW]],
            )

        output = {}
        box_loss = self.loss_box(
            box, box_target, weight=weight, avg_factor=avg_factor
        )
        output[f"{prefix}loss_box{suffix}"] = box_loss

        if quality is not None:
            cns = quality[..., CNS]
            yns = quality[..., YNS].sigmoid()
            # centerness target: exp(-distance between centers)
            cns_target = torch.norm(
                box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]],
                p=2,
                dim=-1,
            )
            cns_target = torch.exp(-cns_target)
            cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor)
            output[f"{prefix}loss_cns{suffix}"] = cns_loss

            # yawness target: 1 when predicted yaw vector agrees with GT
            yns_target = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                > 0
            )
            yns_target = yns_target.float()
            yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor)
            output[f"{prefix}loss_yns{suffix}"] = yns_loss
        return output
projects/mmdet3d_plugin/models/detection3d/target.py
0 → 100644
View file @
afe88104
import
torch
import
numpy
as
np
import
torch.nn.functional
as
F
from
scipy.optimize
import
linear_sum_assignment
from
mmdet.core.bbox.builder
import
BBOX_SAMPLERS
from
projects.mmdet3d_plugin.core.box3d
import
*
from
..base_target
import
BaseTargetWithDenoising
__all__
=
[
"SparseBox3DTarget"
]
@
BBOX_SAMPLERS
.
register_module
()
class
SparseBox3DTarget
(
BaseTargetWithDenoising
):
def
__init__
(
self
,
cls_weight
=
2.0
,
alpha
=
0.25
,
gamma
=
2
,
eps
=
1e-12
,
box_weight
=
0.25
,
reg_weights
=
None
,
cls_wise_reg_weights
=
None
,
num_dn_groups
=
0
,
dn_noise_scale
=
0.5
,
max_dn_gt
=
32
,
add_neg_dn
=
True
,
num_temp_dn_groups
=
0
,
):
super
(
SparseBox3DTarget
,
self
).
__init__
(
num_dn_groups
,
num_temp_dn_groups
)
self
.
cls_weight
=
cls_weight
self
.
box_weight
=
box_weight
self
.
alpha
=
alpha
self
.
gamma
=
gamma
self
.
eps
=
eps
self
.
reg_weights
=
reg_weights
if
self
.
reg_weights
is
None
:
self
.
reg_weights
=
[
1.0
]
*
8
+
[
0.0
]
*
2
self
.
cls_wise_reg_weights
=
cls_wise_reg_weights
self
.
dn_noise_scale
=
dn_noise_scale
self
.
max_dn_gt
=
max_dn_gt
self
.
add_neg_dn
=
add_neg_dn
def
encode_reg_target
(
self
,
box_target
,
device
=
None
):
outputs
=
[]
for
box
in
box_target
:
output
=
torch
.
cat
(
[
box
[...,
[
X
,
Y
,
Z
]],
box
[...,
[
W
,
L
,
H
]].
log
(),
torch
.
sin
(
box
[...,
YAW
]).
unsqueeze
(
-
1
),
torch
.
cos
(
box
[...,
YAW
]).
unsqueeze
(
-
1
),
box
[...,
YAW
+
1
:],
],
dim
=-
1
,
)
if
device
is
not
None
:
output
=
output
.
to
(
device
=
device
)
outputs
.
append
(
output
)
return
outputs
    def sample(
        self,
        cls_pred,
        box_pred,
        cls_target,
        box_target,
    ):
        """Hungarian-match predictions to GT and build dense training targets.

        Args:
            cls_pred: (bs, num_pred, num_cls) classification logits.
            box_pred: (bs, num_pred, state_dims) box predictions.
            cls_target: list (len bs) of per-sample GT label tensors.
            box_target: list (len bs) of per-sample raw GT boxes.

        Returns:
            (output_cls_target, output_box_target, output_reg_weights) —
            dense (bs, num_pred, ...) tensors; unmatched predictions get
            class ``num_cls`` (background) and all-zero box targets/weights.
            Side effect: caches the per-sample assignment in ``self.indices``.
        """
        bs, num_pred, num_cls = cls_pred.shape
        cls_cost = self._cls_cost(cls_pred, cls_target)
        box_target = self.encode_reg_target(box_target, box_pred.device)

        # per-GT regression weights: zero out NaN dims, then apply any
        # class-specific overrides
        instance_reg_weights = []
        for i in range(len(box_target)):
            weights = torch.logical_not(box_target[i].isnan()).to(
                dtype=box_target[i].dtype
            )
            if self.cls_wise_reg_weights is not None:
                for cls, weight in self.cls_wise_reg_weights.items():
                    weights = torch.where(
                        (cls_target[i] == cls)[:, None],
                        weights.new_tensor(weight),
                        weights,
                    )
            instance_reg_weights.append(weights)
        box_cost = self._box_cost(box_pred, box_target, instance_reg_weights)

        indices = []
        for i in range(bs):
            if cls_cost[i] is not None and box_cost[i] is not None:
                cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy()
                # guard the solver against -inf/NaN entries
                cost = np.where(
                    np.isneginf(cost) | np.isnan(cost), 1e8, cost
                )
                assign = linear_sum_assignment(cost)
                indices.append(
                    [
                        cls_pred.new_tensor(x, dtype=torch.int64)
                        for x in assign
                    ]
                )
            else:
                indices.append([None, None])

        # background label is num_cls for every unmatched prediction
        output_cls_target = (
            cls_target[0].new_ones([bs, num_pred], dtype=torch.long)
            * num_cls
        )
        output_box_target = box_pred.new_zeros(box_pred.shape)
        output_reg_weights = box_pred.new_zeros(box_pred.shape)
        for i, (pred_idx, target_idx) in enumerate(indices):
            if len(cls_target[i]) == 0:
                continue
            output_cls_target[i, pred_idx] = cls_target[i][target_idx]
            output_box_target[i, pred_idx] = box_target[i][target_idx]
            output_reg_weights[i, pred_idx] = instance_reg_weights[i][
                target_idx
            ]
        self.indices = indices
        return output_cls_target, output_box_target, output_reg_weights
def
_cls_cost
(
self
,
cls_pred
,
cls_target
):
bs
=
cls_pred
.
shape
[
0
]
cls_pred
=
cls_pred
.
sigmoid
()
cost
=
[]
for
i
in
range
(
bs
):
if
len
(
cls_target
[
i
])
>
0
:
neg_cost
=
(
-
(
1
-
cls_pred
[
i
]
+
self
.
eps
).
log
()
*
(
1
-
self
.
alpha
)
*
cls_pred
[
i
].
pow
(
self
.
gamma
)
)
pos_cost
=
(
-
(
cls_pred
[
i
]
+
self
.
eps
).
log
()
*
self
.
alpha
*
(
1
-
cls_pred
[
i
]).
pow
(
self
.
gamma
)
)
cost
.
append
(
(
pos_cost
[:,
cls_target
[
i
]]
-
neg_cost
[:,
cls_target
[
i
]])
*
self
.
cls_weight
)
else
:
cost
.
append
(
None
)
return
cost
def
_box_cost
(
self
,
box_pred
,
box_target
,
instance_reg_weights
):
bs
=
box_pred
.
shape
[
0
]
cost
=
[]
for
i
in
range
(
bs
):
if
len
(
box_target
[
i
])
>
0
:
cost
.
append
(
torch
.
sum
(
torch
.
abs
(
box_pred
[
i
,
:,
None
]
-
box_target
[
i
][
None
])
*
instance_reg_weights
[
i
][
None
]
*
box_pred
.
new_tensor
(
self
.
reg_weights
),
dim
=-
1
,
)
*
self
.
box_weight
)
else
:
cost
.
append
(
None
)
return
cost
    def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
        """Generate noisy denoising anchors and their matched targets.

        Returns ``(dn_anchor, dn_box_target, dn_cls_target, attn_mask,
        valid_mask, dn_id_target)`` with group-major layout flattened to
        (bs, num_dn_groups * num_gt, ...), or ``None`` when denoising is
        disabled or there is no GT.
        """
        if self.num_dn_groups <= 0:
            return None
        # instance ids are only needed for temporal denoising
        if self.num_temp_dn_groups <= 0:
            gt_instance_id = None

        # cap the number of GT used for denoising
        if self.max_dn_gt > 0:
            cls_target = [x[: self.max_dn_gt] for x in cls_target]
            box_target = [x[: self.max_dn_gt] for x in box_target]
            if gt_instance_id is not None:
                gt_instance_id = [
                    x[: self.max_dn_gt] for x in gt_instance_id
                ]

        # pad every sample to the longest GT list; pad label/id is -1
        max_dn_gt = max([len(x) for x in cls_target])
        if max_dn_gt == 0:
            return None
        cls_target = torch.stack(
            [
                F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                for x in cls_target
            ]
        )
        box_target = self.encode_reg_target(box_target, cls_target.device)
        box_target = torch.stack(
            [
                F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0]))
                for x in box_target
            ]
        )
        # zero out the boxes of padded slots
        box_target = torch.where(
            cls_target[..., None] == -1,
            box_target.new_tensor(0),
            box_target,
        )
        if gt_instance_id is not None:
            gt_instance_id = torch.stack(
                [
                    F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                    for x in gt_instance_id
                ]
            )

        bs, num_gt, state_dims = box_target.shape
        # replicate GT along the batch dim once per dn group
        if self.num_dn_groups > 1:
            cls_target = cls_target.tile(self.num_dn_groups, 1)
            box_target = box_target.tile(self.num_dn_groups, 1, 1)
            if gt_instance_id is not None:
                gt_instance_id = gt_instance_id.tile(
                    self.num_dn_groups, 1
                )

        # positive dn anchors: GT + uniform noise in [-scale, scale]
        noise = torch.rand_like(box_target) * 2 - 1
        noise *= box_target.new_tensor(self.dn_noise_scale)
        dn_anchor = box_target + noise
        if self.add_neg_dn:
            # negative dn anchors: noise magnitude in [scale, 2*scale],
            # random sign per element — pushed away from the GT
            noise_neg = torch.rand_like(box_target) + 1
            flag = torch.where(
                torch.rand_like(box_target) > 0.5,
                noise_neg.new_tensor(1),
                noise_neg.new_tensor(-1),
            )
            noise_neg *= flag
            noise_neg *= box_target.new_tensor(self.dn_noise_scale)
            dn_anchor = torch.cat(
                [dn_anchor, box_target + noise_neg], dim=1
            )
            num_gt *= 2

        # re-match noisy anchors to GT with unit per-dim weights
        box_cost = self._box_cost(
            dn_anchor, box_target, torch.ones_like(box_target)
        )
        dn_box_target = torch.zeros_like(dn_anchor)
        # -3 marks "not matched" (distinct from the -1 padding label)
        dn_cls_target = -torch.ones_like(cls_target) * 3
        if gt_instance_id is not None:
            dn_id_target = -torch.ones_like(gt_instance_id)
        if self.add_neg_dn:
            dn_cls_target = torch.cat(
                [dn_cls_target, dn_cls_target], dim=1
            )
            if gt_instance_id is not None:
                dn_id_target = torch.cat(
                    [dn_id_target, dn_id_target], dim=1
                )
        for i in range(dn_anchor.shape[0]):
            cost = box_cost[i].cpu().numpy()
            anchor_idx, gt_idx = linear_sum_assignment(cost)
            anchor_idx = dn_anchor.new_tensor(
                anchor_idx, dtype=torch.int64
            )
            gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
            dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
            dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
            if gt_instance_id is not None:
                dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]

        # reorder from (groups*bs, ...) to (bs, groups*num_gt, ...)
        dn_anchor = (
            dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_box_target = (
            dn_box_target.reshape(
                self.num_dn_groups, bs, num_gt, state_dims
            )
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_cls_target = (
            dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
            .permute(1, 0, 2)
            .flatten(1)
        )
        if gt_instance_id is not None:
            dn_id_target = (
                dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
        else:
            dn_id_target = None
        valid_mask = dn_cls_target >= 0
        if self.add_neg_dn:
            cls_target = (
                torch.cat([cls_target, cls_target], dim=1)
                .reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
            # "valid" means the slot did not come from padding: matched
            # positives, plus unmatched slots of a real (non-pad) GT
            valid_mask = torch.logical_or(
                valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )
        # block attention across dn groups; within a group it is allowed
        attn_mask = dn_box_target.new_ones(
            num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
        )
        for i in range(self.num_dn_groups):
            start = num_gt * i
            end = start + num_gt
            attn_mask[start:end, start:end] = 0
        attn_mask = attn_mask == 1
        dn_cls_target = dn_cls_target.long()
        return (
            dn_anchor,
            dn_box_target,
            dn_cls_target,
            attn_mask,
            valid_mask,
            dn_id_target,
        )
    def update_dn(
        self,
        instance_feature,
        anchor,
        dn_reg_target,
        dn_cls_target,
        valid_mask,
        dn_id_target,
        num_noraml_anchor,  # NOTE(review): name typo ("normal"); kept for API compat
        temporal_valid_mask,
    ):
        """Merge cached temporal dn groups into the current frame's dn slots.

        The first ``num_temp_dn_groups`` dn groups are replaced (where
        ``temporal_valid_mask`` allows) by the groups cached in
        ``self.dn_metas``, with their targets re-associated by instance id.
        Returns ``(instance_feature, anchor, dn_reg_target, dn_cls_target,
        valid_mask, dn_id_target)``.
        """
        bs, num_anchor = instance_feature.shape[:2]
        if temporal_valid_mask is None:
            self.dn_metas = None
        # nothing cached, or no dn slots appended — pass through unchanged
        if self.dn_metas is None or num_noraml_anchor >= num_anchor:
            return (
                instance_feature,
                anchor,
                dn_reg_target,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )

        # split instance_feature and anchor into non-dn and dn
        num_dn = num_anchor - num_noraml_anchor
        dn_instance_feature = instance_feature[:, -num_dn:]
        dn_anchor = anchor[:, -num_dn:]
        instance_feature = instance_feature[:, :num_noraml_anchor]
        anchor = anchor[:, :num_noraml_anchor]

        # reshape all dn metas from (bs, num_all_dn, ...)
        # to (bs, dn_group, num_dn_per_group, ...)
        num_dn_groups = self.num_dn_groups
        num_dn = num_dn // num_dn_groups
        dn_feat = dn_instance_feature.reshape(
            bs, num_dn_groups, num_dn, -1
        )
        dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
        dn_reg_target = dn_reg_target.reshape(
            bs, num_dn_groups, num_dn, -1
        )
        dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
        valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
        # NOTE(review): dn_id is only bound in this branch but used
        # unconditionally below — callers must pass dn_id_target when a
        # cache exists (forward() guards on dn_id_target is not None).
        if dn_id_target is not None:
            dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)

        # update temp_dn_metas by instance_id
        temp_dn_feat = self.dn_metas["dn_instance_feature"]
        _, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
        temp_dn_id = self.dn_metas["dn_id_target"]
        # match cached dn items to current ones by GT instance id:
        # (bs, num_temp_dn_groups, num_temp_dn, num_dn)
        match = (
            temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
        )
        # pull the current-frame reg target for each matched cached item
        temp_reg_target = (
            match[..., None]
            * dn_reg_target[:, :num_temp_dn_groups, None]
        ).sum(dim=3)
        # cached items whose GT disappeared become background (-1)
        temp_cls_target = torch.where(
            torch.all(torch.logical_not(match), dim=-1),
            self.dn_metas["dn_cls_target"].new_tensor(-1),
            self.dn_metas["dn_cls_target"],
        )
        temp_valid_mask = self.dn_metas["valid_mask"]
        temp_dn_anchor = self.dn_metas["dn_anchor"]

        # handle the misalignment the length of temp_dn to dn caused by the
        # change of num_gt, then concat the temp_dn and dn
        temp_dn_metas = [
            temp_dn_feat,
            temp_dn_anchor,
            temp_reg_target,
            temp_cls_target,
            temp_valid_mask,
            temp_dn_id,
        ]
        dn_metas = [
            dn_feat,
            dn_anchor,
            dn_reg_target,
            dn_cls_target,
            valid_mask,
            dn_id,
        ]
        output = []
        for i, (temp_meta, meta) in enumerate(
            zip(temp_dn_metas, dn_metas)
        ):
            if num_temp_dn < num_dn:
                # pad cached groups up to the current per-group length
                pad = (0, num_dn - num_temp_dn)
                if temp_meta.dim() == 4:
                    pad = (0, 0) + pad
                else:
                    assert temp_meta.dim() == 3
                temp_meta = F.pad(temp_meta, pad, value=0)
            else:
                temp_meta = temp_meta[:, :, :num_dn]
            # per-sample switch: cached groups only where the temporal
            # history is valid, otherwise keep the fresh dn groups
            mask = temporal_valid_mask[:, None, None]
            if meta.dim() == 4:
                mask = mask.unsqueeze(dim=-1)
            temp_meta = torch.where(
                mask, temp_meta, meta[:, :num_temp_dn_groups]
            )
            meta = torch.cat(
                [temp_meta, meta[:, num_temp_dn_groups:]], dim=1
            )
            meta = meta.flatten(1, 2)
            output.append(meta)
        # re-attach the non-dn instances in front of the merged dn slots
        output[0] = torch.cat([instance_feature, output[0]], dim=1)
        output[1] = torch.cat([anchor, output[1]], dim=1)
        return output
def
cache_dn
(
self
,
dn_instance_feature
,
dn_anchor
,
dn_cls_target
,
valid_mask
,
dn_id_target
,
):
if
self
.
num_temp_dn_groups
<
0
:
return
num_dn_groups
=
self
.
num_dn_groups
bs
,
num_dn
=
dn_instance_feature
.
shape
[:
2
]
num_temp_dn
=
num_dn
//
num_dn_groups
temp_group_mask
=
(
torch
.
randperm
(
num_dn_groups
)
<
self
.
num_temp_dn_groups
)
temp_group_mask
=
temp_group_mask
.
to
(
device
=
dn_anchor
.
device
)
dn_instance_feature
=
dn_instance_feature
.
detach
().
reshape
(
bs
,
num_dn_groups
,
num_temp_dn
,
-
1
)[:,
temp_group_mask
]
dn_anchor
=
dn_anchor
.
detach
().
reshape
(
bs
,
num_dn_groups
,
num_temp_dn
,
-
1
)[:,
temp_group_mask
]
dn_cls_target
=
dn_cls_target
.
reshape
(
bs
,
num_dn_groups
,
num_temp_dn
)[
:,
temp_group_mask
]
valid_mask
=
valid_mask
.
reshape
(
bs
,
num_dn_groups
,
num_temp_dn
)[
:,
temp_group_mask
]
if
dn_id_target
is
not
None
:
dn_id_target
=
dn_id_target
.
reshape
(
bs
,
num_dn_groups
,
num_temp_dn
)[:,
temp_group_mask
]
self
.
dn_metas
=
dict
(
dn_instance_feature
=
dn_instance_feature
,
dn_anchor
=
dn_anchor
,
dn_cls_target
=
dn_cls_target
,
valid_mask
=
valid_mask
,
dn_id_target
=
dn_id_target
,
)
projects/mmdet3d_plugin/models/grid_mask.py
0 → 100644
View file @
afe88104
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
PIL
import
Image
class Grid(object):
    """GridMask augmentation for a single image tensor.

    Zeroes out a regular grid of stripes (optionally rotated) on a CHW image.
    ``__call__`` returns ``(img, label)`` so it can be chained in a pipeline.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        """Args:
            use_h: mask horizontal stripes.
            use_w: mask vertical stripes.
            rotate: upper bound (exclusive) of the random rotation in degrees.
            offset: add random noise into the masked-out area.
            ratio: stripe-width to grid-period ratio.
            mode: 0 keeps the grid cells, 1 inverts the mask.
            prob: probability of applying the augmentation.
        """
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Linearly ramp the apply-probability over training epochs."""
        self.prob = self.st_prob * epoch / max_epoch

    def __call__(self, img, label):
        """Apply a random grid mask to ``img`` (CHW tensor); ``label`` passes through."""
        if np.random.rand() > self.prob:
            return img, label
        h = img.size(1)
        w = img.size(2)
        self.d1 = 2
        self.d2 = min(h, w)
        # build the mask on a 1.5x canvas so rotation does not expose borders
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(self.d1, self.d2)
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0
        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        # crop the rotated canvas back to the image size
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]
        # fix: np.asarray(PIL.Image) is read-only; torch.from_numpy requires
        # a writable array (GridMask.forward already copies for this reason)
        mask = torch.from_numpy(mask.copy()).float()
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(
                2 * (np.random.rand(h, w) - 0.5)
            ).float()
            offset = (1 - mask) * offset
            img = img * mask + offset
        else:
            img = img * mask
        return img, label
class GridMask(nn.Module):
    """Batched GridMask augmentation module.

    Zeroes a regular grid of stripes on every image in an (N, C, H, W)
    batch during training; identity in eval mode.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        # use_h / use_w: mask horizontal / vertical stripes
        # rotate: exclusive upper bound of the random rotation (degrees)
        # offset: add random noise into the masked-out area
        # ratio: stripe-width to grid-period ratio
        # mode: 0 keeps grid cells, 1 inverts the mask
        # prob: probability of applying the augmentation (see set_prob)
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Linearly ramp the apply-probability over training epochs."""
        self.prob = self.st_prob * epoch / max_epoch

    def forward(self, x):
        """Apply one random grid mask to the whole (N, C, H, W) batch."""
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        x = x.view(-1, h, w)
        # build the mask on a 1.5x canvas so rotation does not expose borders
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        # NOTE(review): grid period drawn from [2, h) — unlike Grid.__call__,
        # which uses min(h, w); confirm intentional for landscape inputs.
        d = np.random.randint(2, h)
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0
        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        # crop the rotated canvas back to the image size
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]
        # .copy(): np.asarray(PIL.Image) is read-only, from_numpy needs
        # writable memory. NOTE(review): hard-coded .cuda() assumes GPU input.
        mask = torch.from_numpy(mask.copy()).float().cuda()
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = (
                torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
                .float()
                .cuda()
            )
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask
        return x.view(n, c, h, w)
projects/mmdet3d_plugin/models/instance_bank.py
0 → 100644
View file @
afe88104
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
mmcv.utils
import
build_from_cfg
from
mmcv.cnn.bricks.registry
import
PLUGIN_LAYERS
# Public API of this module.
__all__ = ["InstanceBank"]
def topk(confidence, k, *inputs):
    """Keep the ``k`` highest-confidence instances of each sample.

    Args:
        confidence: (bs, N, ...) per-instance scores; top-k is taken
            over dim 1.
        k: number of instances to keep per sample.
        *inputs: tensors of shape (bs, N, ...) gathered with the same
            per-sample indices; each comes back as (bs, k, -1).

    Returns:
        tuple: (top-k confidences, list of gathered tensors).
    """
    batch_size, num_instance = confidence.shape[:2]
    confidence, indices = torch.topk(confidence, k, dim=1)
    # Turn per-sample indices into positions in a flat (bs*N) layout.
    batch_offset = (
        torch.arange(batch_size, device=indices.device)[:, None] * num_instance
    )
    flat_indices = (indices + batch_offset).reshape(-1)
    gathered = [
        tensor.flatten(end_dim=1)[flat_indices].reshape(batch_size, k, -1)
        for tensor in inputs
    ]
    return confidence, gathered
@PLUGIN_LAYERS.register_module()
class InstanceBank(nn.Module):
    """Bank of learnable anchors/features with cross-frame instance caching.

    Holds ``num_anchor`` learnable anchors and instance features, caches the
    top ``num_temp_instances`` instances after each frame, and projects the
    cached anchors into the next frame's coordinate system through an
    optional ``anchor_handler`` (must expose ``anchor_projection``).
    """

    def __init__(
        self,
        num_anchor,
        embed_dims,
        anchor,
        anchor_handler=None,
        num_temp_instances=0,
        default_time_interval=0.5,
        confidence_decay=0.6,
        anchor_grad=True,
        feat_grad=True,
        max_time_interval=2,
    ):
        super(InstanceBank, self).__init__()
        self.embed_dims = embed_dims
        self.num_temp_instances = num_temp_instances
        self.default_time_interval = default_time_interval
        self.confidence_decay = confidence_decay
        self.max_time_interval = max_time_interval
        if anchor_handler is not None:
            anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS)
            # The handler is used below to move cached anchors between frames.
            assert hasattr(anchor_handler, "anchor_projection")
        self.anchor_handler = anchor_handler
        # Anchors may come as a .npy path, a python list/tuple, or an array.
        if isinstance(anchor, str):
            anchor = np.load(anchor)
        elif isinstance(anchor, (list, tuple)):
            anchor = np.array(anchor)
        if len(anchor.shape) == 3:  # for map
            anchor = anchor.reshape(anchor.shape[0], -1)
        self.num_anchor = min(len(anchor), num_anchor)
        anchor = anchor[:num_anchor]
        self.anchor = nn.Parameter(
            torch.tensor(anchor, dtype=torch.float32),
            requires_grad=anchor_grad,
        )
        # Kept so init_weight() can restore the initial anchors.
        self.anchor_init = anchor
        self.instance_feature = nn.Parameter(
            torch.zeros([self.anchor.shape[0], self.embed_dims]),
            requires_grad=feat_grad,
        )
        self.reset()

    def init_weight(self):
        """Re-initialize anchors from their saved values; xavier features."""
        self.anchor.data = self.anchor.data.new_tensor(self.anchor_init)
        if self.instance_feature.requires_grad:
            torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)

    def reset(self):
        """Drop all cached temporal state (new sequence / batch mismatch)."""
        self.cached_feature = None
        self.cached_anchor = None
        self.metas = None
        self.mask = None
        self.confidence = None
        self.temp_confidence = None
        self.instance_id = None
        self.prev_id = 0

    def get(self, batch_size, metas=None, dn_metas=None):
        """Return current-frame instances plus projected cached instances.

        Returns:
            tuple: (instance_feature, anchor, cached_feature, cached_anchor,
            time_interval).
        """
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1)
        )
        anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))

        if (
            self.cached_anchor is not None
            and batch_size == self.cached_anchor.shape[0]
        ):
            history_time = self.metas["timestamp"]
            time_interval = metas["timestamp"] - history_time
            time_interval = time_interval.to(dtype=instance_feature.dtype)
            # Samples whose gap exceeds max_time_interval are treated as
            # having no usable history (see update()).
            self.mask = torch.abs(time_interval) <= self.max_time_interval

            if self.anchor_handler is not None:
                # Ego transform from the cached frame into the current one.
                T_temp2cur = self.cached_anchor.new_tensor(
                    np.stack(
                        [
                            x["T_global_inv"]
                            @ self.metas["img_metas"][i]["T_global"]
                            for i, x in enumerate(metas["img_metas"])
                        ]
                    )
                )
                self.cached_anchor = self.anchor_handler.anchor_projection(
                    self.cached_anchor,
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]

            if (
                self.anchor_handler is not None
                and dn_metas is not None
                and batch_size == dn_metas["dn_anchor"].shape[0]
            ):
                # Project denoising anchors with the same transform.
                num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3]
                dn_anchor = self.anchor_handler.anchor_projection(
                    dn_metas["dn_anchor"].flatten(1, 2),
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]
                dn_metas["dn_anchor"] = dn_anchor.reshape(
                    batch_size, num_dn_group, num_dn, -1
                )
            # Replace zero / invalid intervals with the default.
            time_interval = torch.where(
                torch.logical_and(time_interval != 0, self.mask),
                time_interval,
                time_interval.new_tensor(self.default_time_interval),
            )
        else:
            self.reset()
            time_interval = instance_feature.new_tensor(
                [self.default_time_interval] * batch_size
            )

        return (
            instance_feature,
            anchor,
            self.cached_feature,
            self.cached_anchor,
            time_interval,
        )

    def update(self, instance_feature, anchor, confidence):
        """Merge cached temporal instances into the current-frame set.

        Keeps the cached instances plus the top
        (num_anchor - num_temp_instances) fresh instances; samples with an
        invalid time gap (self.mask False) keep the fresh set unchanged.
        Denoising instances appended past ``num_anchor`` are passed through.
        """
        if self.cached_feature is None:
            return instance_feature, anchor

        num_dn = 0
        if instance_feature.shape[1] > self.num_anchor:
            # Split off the trailing denoising instances.
            num_dn = instance_feature.shape[1] - self.num_anchor
            dn_instance_feature = instance_feature[:, -num_dn:]
            dn_anchor = anchor[:, -num_dn:]
            instance_feature = instance_feature[:, : self.num_anchor]
            anchor = anchor[:, : self.num_anchor]
            confidence = confidence[:, : self.num_anchor]

        N = self.num_anchor - self.num_temp_instances
        confidence = confidence.max(dim=-1).values
        _, (selected_feature, selected_anchor) = topk(
            confidence, N, instance_feature, anchor
        )
        selected_feature = torch.cat(
            [self.cached_feature, selected_feature], dim=1
        )
        selected_anchor = torch.cat(
            [self.cached_anchor, selected_anchor], dim=1
        )
        instance_feature = torch.where(
            self.mask[:, None, None], selected_feature, instance_feature
        )
        anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
        # Invalidate cached confidence / ids where history is unusable.
        self.confidence = torch.where(
            self.mask[:, None],
            self.confidence,
            self.confidence.new_tensor(0),
        )
        if self.instance_id is not None:
            self.instance_id = torch.where(
                self.mask[:, None],
                self.instance_id,
                self.instance_id.new_tensor(-1),
            )

        if num_dn > 0:
            instance_feature = torch.cat(
                [instance_feature, dn_instance_feature], dim=1
            )
            anchor = torch.cat([anchor, dn_anchor], dim=1)
        return instance_feature, anchor

    def cache(
        self,
        instance_feature,
        anchor,
        confidence,
        metas=None,
        feature_maps=None,
    ):
        """Cache the top ``num_temp_instances`` instances for the next frame.

        Cached confidences are decayed by ``confidence_decay`` and kept if
        still higher than the fresh scores (per-instance maximum).
        """
        if self.num_temp_instances <= 0:
            return
        instance_feature = instance_feature.detach()
        anchor = anchor.detach()
        confidence = confidence.detach()

        self.metas = metas
        confidence = confidence.max(dim=-1).values.sigmoid()
        if self.confidence is not None:
            confidence[:, : self.num_temp_instances] = torch.maximum(
                self.confidence * self.confidence_decay,
                confidence[:, : self.num_temp_instances],
            )
        self.temp_confidence = confidence

        (
            self.confidence,
            (self.cached_feature, self.cached_anchor),
        ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)

    def get_instance_id(self, confidence, anchor=None, threshold=None):
        """Assign persistent ids: carry over cached ids, mint new ones.

        New ids are only minted for instances without a cached id (and,
        if ``threshold`` is given, with confidence >= threshold).
        """
        confidence = confidence.max(dim=-1).values.sigmoid()
        instance_id = confidence.new_full(confidence.shape, -1).long()

        if (
            self.instance_id is not None
            and self.instance_id.shape[0] == instance_id.shape[0]
        ):
            instance_id[:, : self.instance_id.shape[1]] = self.instance_id

        mask = instance_id < 0
        if threshold is not None:
            mask = mask & (confidence >= threshold)
        num_new_instance = mask.sum()
        new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id
        instance_id[torch.where(mask)] = new_ids
        self.prev_id += num_new_instance
        self.update_instance_id(instance_id, confidence)
        return instance_id

    def update_instance_id(self, instance_id=None, confidence=None):
        """Store ids of the top temporal instances, padded with -1."""
        if self.temp_confidence is None:
            if confidence.dim() == 3:  # bs, num_anchor, num_cls
                temp_conf = confidence.max(dim=-1).values
            else:  # bs, num_anchor
                temp_conf = confidence
        else:
            temp_conf = self.temp_confidence
        instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][
            0
        ]
        instance_id = instance_id.squeeze(dim=-1)
        # Pad back to num_anchor so shapes line up next frame.
        self.instance_id = F.pad(
            instance_id,
            (0, self.num_anchor - self.num_temp_instances),
            value=-1,
        )
\ No newline at end of file
projects/mmdet3d_plugin/models/map/__init__.py
0 → 100644
View file @
afe88104
from
.decoder
import
SparsePoint3DDecoder
from
.target
import
SparsePoint3DTarget
,
HungarianLinesAssigner
from
.match_cost
import
LinesL1Cost
,
MapQueriesCost
from
.loss
import
LinesL1Loss
,
SparseLineLoss
from
.map_blocks
import
(
SparsePoint3DRefinementModule
,
SparsePoint3DKeyPointsGenerator
,
SparsePoint3DEncoder
,
)
\ No newline at end of file
projects/mmdet3d_plugin/models/map/decoder.py
0 → 100644
View file @
afe88104
from
typing
import
Optional
,
List
import
torch
from
mmdet.core.bbox.builder
import
BBOX_CODERS
@BBOX_CODERS.register_module()
class SparsePoint3DDecoder(object):
    """Turn raw map-head outputs into per-sample vector predictions.

    Takes the last decoder layer's classification logits and flattened
    point regressions, performs a joint top-k over (instance, class)
    pairs, optionally filters by score, and returns numpy results.
    """

    def __init__(
        self,
        coords_dim: int = 2,
        score_threshold: Optional[float] = None,
    ):
        super(SparsePoint3DDecoder, self).__init__()
        self.score_threshold = score_threshold
        self.coords_dim = coords_dim

    def decode(
        self,
        cls_scores,
        pts_preds,
        instance_id=None,
        quality=None,
        output_idx=-1,
    ):
        """Decode one batch; returns a list of dicts with numpy arrays.

        Each dict has keys "vectors" (list of (num_pts, coords_dim)
        arrays), "scores", and "labels".
        """
        # Only the final layer's output is decoded; `output_idx`,
        # `instance_id` and `quality` are accepted for interface parity.
        batch_size, num_pred, num_cls = cls_scores[-1].shape
        scores_all = cls_scores[-1].sigmoid()
        points_all = pts_preds[-1].reshape(
            batch_size, num_pred, -1, self.coords_dim
        )
        # Joint top-k across the flattened (instance, class) axis.
        scores_all, flat_idx = scores_all.flatten(start_dim=1).topk(
            num_pred, dim=1
        )
        labels_all = flat_idx % num_cls
        keep = None
        if self.score_threshold is not None:
            keep = scores_all >= self.score_threshold
        results = []
        for i in range(batch_size):
            labels = labels_all[i]
            scores = scores_all[i]
            points = points_all[i, flat_idx[i] // num_cls]
            if keep is not None:
                labels = labels[keep[i]]
                scores = scores[keep[i]]
                points = points[keep[i]]
            results.append(
                {
                    "vectors": [
                        vec.detach().cpu().numpy() for vec in points
                    ],
                    "scores": scores.detach().cpu().numpy(),
                    "labels": labels.detach().cpu().numpy(),
                }
            )
        return results
\ No newline at end of file
projects/mmdet3d_plugin/models/map/loss.py
0 → 100644
View file @
afe88104
import
torch
import
torch.nn
as
nn
from
mmcv.utils
import
build_from_cfg
from
mmdet.models.builder
import
LOSSES
from
mmdet.models.losses
import
l1_loss
,
smooth_l1_loss
@LOSSES.register_module()
class LinesL1Loss(nn.Module):
    """(Smooth-)L1 loss over polyline coordinates, averaged per 2D point.

    With ``beta > 0`` this is mmdet's smooth-L1; with ``beta == 0`` it is
    plain L1. The raw loss is divided by the number of points encoded in
    the last dimension (pairs of coordinates) and scaled by
    ``loss_weight``.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        """
        L1 loss. The same as the smooth L1 loss
        Args:
            reduction (str, optional): The method to reduce the loss.
                Options are "none", "mean" and "sum".
            loss_weight (float, optional): The weight of loss.
        """
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self, pred, target, weight=None, avg_factor=None,
                reduction_override=None):
        """Compute the weighted, per-point-normalized line loss.

        Args:
            pred (torch.Tensor): predictions, shape [bs, ...].
            target (torch.Tensor): targets, same shape as ``pred``.
            weight (torch.Tensor, optional): element-wise loss weights;
                useful when not all predictions are valid.
            avg_factor (int, optional): averaging factor for the loss.
            reduction_override (str, optional): overrides
                ``self.reduction`` for this call.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction
        )
        shared = dict(reduction=reduction, avg_factor=avg_factor)
        if self.beta > 0:
            raw = smooth_l1_loss(pred, target, weight, beta=self.beta,
                                 **shared)
        else:
            raw = l1_loss(pred, target, weight, **shared)
        # The last dim packs (x, y) pairs; normalize by the point count.
        points_per_line = pred.shape[-1] // 2
        raw = raw / points_per_line
        return raw * self.loss_weight
@LOSSES.register_module()
class SparseLineLoss(nn.Module):
    """Wraps a line-regression loss, normalizing coordinates to the ROI.

    Both predictions and targets are mapped from metric, ego-centric
    coordinates into (0, 1) before the wrapped loss is applied.
    """

    def __init__(self, loss_line, num_sample=20, roi_size=(30, 60)):
        super().__init__()
        # The wrapped loss comes from a config dict (None stays None).
        self.loss_line = (
            build_from_cfg(loss_line, LOSSES)
            if loss_line is not None
            else None
        )
        self.num_sample = num_sample
        self.roi_size = roi_size

    def forward(self, line, line_target, weight=None, avg_factor=None,
                prefix="", suffix="", **kwargs):
        """Return {"<prefix>loss_line<suffix>": loss} on normalized lines."""
        pred = self.normalize_line(line)
        target = self.normalize_line(line_target)
        value = self.loss_line(
            pred, target, weight=weight, avg_factor=avg_factor
        )
        return {f"{prefix}loss_line{suffix}": value}

    def normalize_line(self, line):
        """Map flattened line coords into (0, 1) relative to the ROI."""
        if line.shape[0] == 0:
            return line
        pts = line.view(line.shape[:-1] + (self.num_sample, -1))
        # Shift so the ROI's corner at (-w/2, -h/2) becomes the origin.
        shift = pts.new_tensor(
            [self.roi_size[0] / 2, self.roi_size[1] / 2]
        )
        pts = pts + shift
        # transform from range [0, 1] to (0, 1)
        eps = 1e-5
        denom = pts.new_tensor(
            [self.roi_size[0], self.roi_size[1]]
        ) + eps
        pts = pts / denom
        return pts.flatten(-2, -1)
projects/mmdet3d_plugin/models/map/map_blocks.py
0 → 100644
View file @
afe88104
from
typing
import
Optional
,
List
,
Tuple
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
mmcv.cnn
import
Linear
,
Scale
,
bias_init_with_prob
from
mmcv.runner.base_module
import
Sequential
,
BaseModule
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.registry
import
(
PLUGIN_LAYERS
,
POSITIONAL_ENCODING
,
)
from
..blocks
import
linear_relu_ln
@POSITIONAL_ENCODING.register_module()
class SparsePoint3DEncoder(BaseModule):
    """Embed flattened polyline anchors into positional features.

    The anchor is a flattened (num_sample * coords_dim) vector; a small
    Linear/ReLU/LayerNorm stack maps it to ``embed_dims``.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        coords_dim: int = 2,
    ):
        super(SparsePoint3DEncoder, self).__init__()
        self.embed_dims = embed_dims
        self.input_dims = num_sample * coords_dim
        # One embedding MLP over the full flattened anchor.
        self.pos_fc = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2, self.input_dims)
        )

    def forward(self, anchor: torch.Tensor):
        return self.pos_fc(anchor)
@PLUGIN_LAYERS.register_module()
class SparsePoint3DRefinementModule(BaseModule):
    """Refine polyline anchors and (optionally) classify them.

    The regression branch predicts a delta added onto the anchor; the
    classification branch scores the instance feature alone.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        coords_dim: int = 2,
        num_cls: int = 3,
        with_cls_branch: bool = True,
    ):
        super(SparsePoint3DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.num_sample = num_sample
        self.output_dim = num_sample * coords_dim
        self.num_cls = num_cls
        # Regression head: MLP -> per-coordinate delta, learnable scale.
        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )
        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )

    def init_weight(self):
        if self.with_cls_branch:
            # Focal-loss style prior: initial class probability ~= 0.01.
            prior_bias = bias_init_with_prob(0.01)
            nn.init.constant_(self.cls_layers[-1].bias, prior_bias)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Return (refined_anchor, cls_logits_or_None, quality=None)."""
        delta = self.layers(instance_feature + anchor_embed)
        refined = delta + anchor
        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            cls = self.cls_layers(instance_feature)
            ## NOTE anchor embed?
        else:
            cls = None
        qt = None  # no quality estimate for map instances
        return refined, cls, qt
@PLUGIN_LAYERS.register_module()
class SparsePoint3DKeyPointsGenerator(BaseModule):
    """Generate 3D sampling key-points from 2D polyline anchors.

    Each of the ``num_sample`` anchor points spawns
    ``len(fix_height) * num_learnable_pts`` key-points: a learned 2D
    offset plus a fixed height offset, lifted to 3D at ``ground_height``.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        num_learnable_pts: int = 0,
        fix_height: Tuple = (0,),
        ground_height: int = 0,
    ):
        super(SparsePoint3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_sample = num_sample
        self.num_learnable_pts = num_learnable_pts
        # Total key-points per anchor instance.
        self.num_pts = num_sample * len(fix_height) * num_learnable_pts
        if self.num_learnable_pts > 0:
            # Predicts a 2D offset per key-point from the instance feature.
            self.learnable_fc = Linear(self.embed_dims, self.num_pts * 2)

        self.fix_height = np.array(fix_height)
        self.ground_height = ground_height

    def init_weight(self):
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Return key-points (bs, num_anchor, num_pts, 3); when temporal
        transforms are supplied, also a list of key-points projected into
        each past frame.
        """
        assert self.num_learnable_pts > 0, 'No learnable pts'
        bs, num_anchor, _ = anchor.shape
        # (bs, num_anchor, num_sample, coords_dim)
        key_points = anchor.view(bs, num_anchor, self.num_sample, -1)
        # Learned offsets: one per (sample, height, learnable_pt) triple.
        offset = (
            self.learnable_fc(instance_feature)
            .reshape(
                bs,
                num_anchor,
                self.num_sample,
                len(self.fix_height),
                self.num_learnable_pts,
                2,
            )
        )
        key_points = offset + key_points[..., None, None, :]
        # Append z = ground_height to lift the 2D points to 3D.
        key_points = torch.cat(
            [
                key_points,
                key_points.new_full(
                    key_points.shape[:-1] + (1,),
                    fill_value=self.ground_height,
                ),
            ],
            dim=-1,
        )
        fix_height = key_points.new_tensor(self.fix_height)
        # Per-height offset only affects the z component.
        height_offset = key_points.new_zeros([len(fix_height), 2])
        height_offset = torch.cat(
            [height_offset, fix_height[:, None]], dim=-1
        )
        key_points = key_points + height_offset[None, None, None, :, None]
        # Flatten (sample, height, learnable_pt) into one key-point axis.
        key_points = key_points.flatten(2, 4)
        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        # Project the key-points into each past frame via homogeneous
        # transforms (rotation + translation rows of a 4x4 matrix).
        temp_key_points_list = []
        for i, t_time in enumerate(temp_timestamps):
            temp_key_points = key_points
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    # @staticmethod
    def anchor_projection(
        self,
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project flattened 2D polyline anchors through each transform.

        Only the 2D rotation block (:2, :2) and 2D translation (:2, 3)
        of each 4x4 transform are applied; timestamps/intervals are
        accepted for interface parity but unused here.
        """
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            dst_anchor = anchor.clone()
            bs, num_anchor, _ = anchor.shape
            # (bs, num_anchor * num_sample, 2) for pointwise transform.
            dst_anchor = dst_anchor.reshape(
                bs, num_anchor, self.num_sample, -1
            ).flatten(1, 2)
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )
            dst_anchor = (
                torch.matmul(
                    T_src2dst[..., :2, :2], dst_anchor[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :2, 3]
            )
            # Back to flattened per-anchor layout.
            dst_anchor = dst_anchor.reshape(
                bs, num_anchor, self.num_sample, -1
            ).flatten(2, 3)
            dst_anchors.append(dst_anchor)
        return dst_anchors
\ No newline at end of file
projects/mmdet3d_plugin/models/map/match_cost.py
0 → 100644
View file @
afe88104
import
torch
from
mmdet.core.bbox.match_costs.builder
import
MATCH_COST
from
mmdet.core.bbox.match_costs
import
build_match_cost
from
torch.nn.functional
import
smooth_l1_loss
@MATCH_COST.register_module()
class LinesL1Cost(object):
    """Pairwise L1 / smooth-L1 matching cost between lines.

    Args:
        weight (int | float): scale applied to the final cost matrix.
        beta (float): smooth-L1 transition point; 0 selects plain L1
            (computed via ``torch.cdist``).
        permute (bool): if True, ``gt_lines`` carries a permutation axis
            and the cheapest permutation per (pred, gt) pair is chosen.
    """

    def __init__(self, weight=1.0, beta=0.0, permute=False):
        self.weight = weight
        self.permute = permute
        self.beta = beta

    def __call__(self, lines_pred, gt_lines, **kwargs):
        """Compute the per-point-normalized cost matrix.

        Args:
            lines_pred (Tensor): [num_pred, 2*num_points], normalized.
            gt_lines (Tensor): [num_gt, 2*num_points], or with
                ``permute`` [num_gt, num_permute, 2*num_points].

        Returns:
            Tensor [num_pred, num_gt] (weighted); with ``permute`` also
            the chosen permutation index per pair.
        """
        expected_rank = 3 if self.permute else 2
        assert len(gt_lines.shape) == expected_rank
        num_pred, num_gt = len(lines_pred), len(gt_lines)
        if self.permute:
            # permute-invarint labels
            gt_lines = gt_lines.flatten(0, 1)
            # (num_gt*num_permute, 2*num_pts)

        num_pts = lines_pred.shape[-1] // 2
        if self.beta > 0:
            expanded_pred = lines_pred.unsqueeze(1).repeat(
                1, len(gt_lines), 1
            )
            expanded_gt = gt_lines.unsqueeze(0).repeat(num_pred, 1, 1)
            dist_mat = smooth_l1_loss(
                expanded_pred, expanded_gt, reduction='none', beta=self.beta
            ).sum(-1)
        else:
            dist_mat = torch.cdist(lines_pred, gt_lines, p=1)

        dist_mat = dist_mat / num_pts

        if self.permute:
            # dist_mat: (num_pred, num_gt*num_permute)
            dist_mat = dist_mat.view(num_pred, num_gt, -1)
            # (num_pred, num_gt, num_permute)
            dist_mat, gt_permute_index = torch.min(dist_mat, 2)
            return dist_mat * self.weight, gt_permute_index
        return dist_mat * self.weight
@MATCH_COST.register_module()
class MapQueriesCost(object):
    """Combined matching cost (classification + line regression [+ IoU])."""

    def __init__(self, cls_cost, reg_cost, iou_cost=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)

        self.iou_cost = None
        if iou_cost is not None:
            self.iou_cost = build_match_cost(iou_cost)

    def __call__(self, preds: dict, gts: dict, ignore_cls_cost: bool):
        """Return the total cost matrix (and, when the regression cost
        uses permutation, the chosen permutation indices)."""
        # classification and bboxcost.
        cls_cost = self.cls_cost(preds['scores'], gts['labels'])

        # regression cost
        regkwargs = {}
        if 'masks' in preds and 'masks' in gts:
            # NOTE(review): ``DynamicLinesCost`` is not imported in this
            # module, so entering this branch would raise NameError —
            # confirm the intended import or dead-code status.
            assert isinstance(self.reg_cost, DynamicLinesCost), ' Issues!!'
            regkwargs = {
                'masks_pred': preds['masks'],
                'masks_gt': gts['masks'],
            }

        reg_cost = self.reg_cost(preds['lines'], gts['lines'], **regkwargs)
        if self.reg_cost.permute:
            reg_cost, gt_permute_idx = reg_cost

        # weighted sum of above three costs
        if ignore_cls_cost:
            cost = reg_cost
        else:
            cost = cls_cost + reg_cost

        # Iou
        if self.iou_cost is not None:
            iou_cost = self.iou_cost(preds['lines'], gts['lines'])
            cost += iou_cost

        if self.reg_cost.permute:
            return cost, gt_permute_idx
        return cost
projects/mmdet3d_plugin/models/map/target.py
0 → 100644
View file @
afe88104
import
torch
import
numpy
as
np
import
torch.nn.functional
as
F
from
scipy.optimize
import
linear_sum_assignment
from
mmdet.core.bbox.builder
import
(
BBOX_SAMPLERS
,
BBOX_ASSIGNERS
)
from
mmdet.core.bbox.match_costs
import
build_match_cost
from
mmdet.core
import
(
build_assigner
,
build_sampler
)
from
mmdet.core.bbox.assigners
import
(
AssignResult
,
BaseAssigner
)
from
..base_target
import
BaseTargetWithDenoising
@BBOX_SAMPLERS.register_module()
class SparsePoint3DTarget(BaseTargetWithDenoising):
    """Build per-anchor map-line training targets via Hungarian matching.

    Unmatched predictions receive class id ``num_cls`` (background) and
    zero regression weight.
    """

    def __init__(
        self,
        assigner=None,
        num_dn_groups=0,
        dn_noise_scale=0.5,
        max_dn_gt=32,
        add_neg_dn=True,
        num_temp_dn_groups=0,
        num_cls=3,
        num_sample=20,
        roi_size=(30, 60),
    ):
        super(SparsePoint3DTarget, self).__init__(
            num_dn_groups, num_temp_dn_groups
        )
        self.assigner = build_assigner(assigner)
        self.dn_noise_scale = dn_noise_scale
        self.max_dn_gt = max_dn_gt
        self.add_neg_dn = add_neg_dn

        self.num_cls = num_cls
        self.num_sample = num_sample
        self.roi_size = roi_size

    def sample(
        self,
        cls_preds,
        pts_preds,
        cls_targets,
        pts_targets,
    ):
        """Assign GT lines to predictions; return (cls_target, box_target,
        reg_weights) aligned with the predictions."""
        # Flatten (num_permute, num_pts, 2) targets to 2D per GT line.
        pts_targets = [
            x.flatten(2, 3) if len(x.shape) == 4 else x for x in pts_targets
        ]
        indices = []
        # Match each batch element independently.
        for (cls_pred, pts_pred, cls_target, pts_target) in zip(
            cls_preds, pts_preds, cls_targets, pts_targets
        ):
            # normalize to (0, 1)
            pts_pred = self.normalize_line(pts_pred)
            pts_target = self.normalize_line(pts_target)
            preds = dict(lines=pts_pred, scores=cls_pred)
            gts = dict(lines=pts_target, labels=cls_target)
            indice = self.assigner.assign(preds, gts)
            indices.append(indice)

        bs, num_pred, num_cls = cls_preds.shape
        # Default class target = num_cls (background).
        output_cls_target = (
            cls_targets[0].new_ones([bs, num_pred], dtype=torch.long)
            * num_cls
        )
        output_box_target = pts_preds.new_zeros(pts_preds.shape)
        output_reg_weights = pts_preds.new_zeros(pts_preds.shape)
        for i, (pred_idx, target_idx, gt_permute_index) in enumerate(
            indices
        ):
            if len(cls_targets[i]) == 0:
                continue
            # Pick the best permutation of each matched GT line.
            permute_idx = gt_permute_index[pred_idx, target_idx]
            output_cls_target[i, pred_idx] = cls_targets[i][target_idx]
            output_box_target[i, pred_idx] = pts_targets[i][
                target_idx, permute_idx
            ]
            output_reg_weights[i, pred_idx] = 1

        return output_cls_target, output_box_target, output_reg_weights

    def normalize_line(self, line):
        """Map flattened line coords into (0, 1) relative to the ROI."""
        if line.shape[0] == 0:
            return line

        line = line.view(line.shape[:-1] + (self.num_sample, -1))

        # Shift the ROI corner (-w/2, -h/2) to the origin.
        origin = -line.new_tensor(
            [self.roi_size[0] / 2, self.roi_size[1] / 2]
        )
        line = line - origin

        # transform from range [0, 1] to (0, 1)
        eps = 1e-5
        norm = line.new_tensor(
            [self.roi_size[0], self.roi_size[1]]
        ) + eps
        line = line / norm
        line = line.flatten(-2, -1)

        return line
@BBOX_ASSIGNERS.register_module()
class HungarianLinesAssigner(BaseAssigner):
    """One-to-one Hungarian matcher between predicted and GT map lines.

    The total cost is produced by a configurable ``cost`` module (a
    weighted sum of classification and line-regression terms); the
    assignment itself is solved on CPU with
    ``scipy.optimize.linear_sum_assignment``. Ground truth never includes
    a no-object class, so unmatched predictions are implicitly background.
    """

    def __init__(self, cost=dict, **kwargs):
        # NOTE(review): the default for ``cost`` is the ``dict`` *type*,
        # which ``build_match_cost`` cannot consume — a config dict must
        # always be passed explicitly; confirm against callers.
        self.cost = build_match_cost(cost)

    def assign(self, preds: dict, gts: dict, ignore_cls_cost=False,
               gt_bboxes_ignore=None, eps=1e-7):
        """Match predictions to ground truth by minimizing the total cost.

        Args:
            preds (dict): 'lines' [num_pred, 2*num_pts] and 'scores'.
            gts (dict): 'lines' and 'labels'.
            ignore_cls_cost (bool): drop the classification term.
            gt_bboxes_ignore: unsupported; must be None.
            eps (int | float): unused; kept for interface compatibility.

        Returns:
            tuple: (matched_row_inds, matched_col_inds, gt_permute_idx);
            all three are None when either side is empty.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0)
        if num_gts == 0 or num_lines == 0:
            return None, None, None

        # compute the weighted costs
        gt_permute_idx = None  # (num_preds, num_gts) when permute is on
        if self.cost.reg_cost.permute:
            cost, gt_permute_idx = self.cost(preds, gts, ignore_cls_cost)
        else:
            cost = self.cost(preds, gts, ignore_cls_cost)

        # do Hungarian matching on CPU using linear_sum_assignment
        cost_matrix = cost.detach().cpu().numpy()
        matched_row_inds, matched_col_inds = linear_sum_assignment(
            cost_matrix
        )
        return matched_row_inds, matched_col_inds, gt_permute_idx
\ No newline at end of file
projects/mmdet3d_plugin/models/motion/__init__.py
0 → 100644
View file @
afe88104
from
.motion_planning_head
import
MotionPlanningHead
from
.motion_blocks
import
MotionPlanningRefinementModule
from
.instance_queue
import
InstanceQueue
from
.target
import
MotionTarget
,
PlanningTarget
from
.decoder
import
SparseBox3DMotionDecoder
,
HierarchicalPlanningDecoder
projects/mmdet3d_plugin/models/motion/decoder.py
0 → 100644
View file @
afe88104
from
typing
import
Optional
import
numpy
as
np
import
torch
from
mmdet.core.bbox.builder
import
BBOX_CODERS
from
projects.mmdet3d_plugin.core.box3d
import
*
from
projects.mmdet3d_plugin.models.detection3d.decoder
import
*
from
projects.mmdet3d_plugin.datasets.utils
import
box3d_to_corners
@BBOX_CODERS.register_module()
class SparseBox3DMotionDecoder(SparseBox3DDecoder):
    """Box decoder extended with per-instance motion forecasts.

    On top of the parent's top-k box selection, it gathers each kept
    instance's trajectory forecast, trajectory scores, historical anchor
    queue and period, and returns them per sample.
    """

    def __init__(self):
        super(SparseBox3DMotionDecoder, self).__init__()

    def decode(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        quality=None,
        motion_output=None,
        output_idx=-1,
    ):
        """Decode one batch into a list of per-sample result dicts.

        Each dict contains "trajs_3d", "trajs_score", "anchor_queue",
        and "period" (all on CPU).
        """
        # With instance ids, scores collapse to one class per instance.
        squeeze_cls = instance_id is not None

        cls_scores = cls_scores[output_idx].sigmoid()

        if squeeze_cls:
            cls_scores, cls_ids = cls_scores.max(dim=-1)
            cls_scores = cls_scores.unsqueeze(dim=-1)

        box_preds = box_preds[output_idx]
        bs, num_pred, num_cls = cls_scores.shape
        # Joint top-k over the flattened (instance, class) axis.
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )
        if not squeeze_cls:
            cls_ids = indices % num_cls
        if self.score_threshold is not None:
            mask = cls_scores >= self.score_threshold

        if quality[output_idx] is None:
            quality = None
        if quality is not None:
            # Re-rank scores by predicted centerness, then re-sort all
            # gathered tensors with the new order.
            centerness = quality[output_idx][..., CNS]
            centerness = torch.gather(centerness, 1, indices // num_cls)
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()
            cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
            if not squeeze_cls:
                cls_ids = torch.gather(cls_ids, 1, idx)
            if self.score_threshold is not None:
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)

        output = []
        # Stack the per-frame anchor history along a queue axis.
        anchor_queue = motion_output["anchor_queue"]
        anchor_queue = torch.stack(anchor_queue, dim=2)
        period = motion_output["period"]
        for i in range(bs):
            category_ids = cls_ids[i]
            if squeeze_cls:
                category_ids = category_ids[indices[i]]
            scores = cls_scores[i]
            box = box_preds[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                category_ids = category_ids[mask[i]]
                scores = scores[mask[i]]
                box = box[mask[i]]
            if quality is not None:
                scores_origin = cls_scores_origin[i]
                if self.score_threshold is not None:
                    scores_origin = scores_origin[mask[i]]

            box = decode_box(box)

            trajs = motion_output["prediction"][-1]
            traj_cls = motion_output["classification"][-1].sigmoid()
            traj = trajs[i, indices[i] // num_cls]
            traj_cls = traj_cls[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                traj = traj[mask[i]]
                traj_cls = traj_cls[mask[i]]
            # Forecasts are per-step deltas; integrate and shift to the
            # decoded box center (x, y).
            traj = traj.cumsum(dim=-2) + box[:, None, None, :2]
            output.append(
                {
                    "trajs_3d": traj.cpu(),
                    "trajs_score": traj_cls.cpu(),
                }
            )

            temp_anchor = anchor_queue[i, indices[i] // num_cls]
            temp_period = period[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                temp_anchor = temp_anchor[mask[i]]
                temp_period = temp_period[mask[i]]
            num_pred, queue_len = temp_anchor.shape[:2]
            # decode_box works on flat boxes; flatten, decode, restore.
            temp_anchor = temp_anchor.flatten(0, 1)
            temp_anchor = decode_box(temp_anchor)
            temp_anchor = temp_anchor.reshape(
                [num_pred, queue_len, box.shape[-1]]
            )
            output[-1]['anchor_queue'] = temp_anchor.cpu()
            output[-1]['period'] = temp_period.cpu()
        return output
@BBOX_CODERS.register_module()
class HierarchicalPlanningDecoder(object):
    """Decode hierarchical planning outputs into per-sample results.

    Selects the planning mode for the ground-truth driving command,
    optionally rescores ego trajectory modes by checking collisions
    against predicted agent motions, and packages everything for
    evaluation (moved to CPU).
    """

    def __init__(
        self,
        ego_fut_ts,
        ego_fut_mode,
        use_rescore=False,
    ):
        """
        Args:
            ego_fut_ts (int): number of future timesteps per ego trajectory.
            ego_fut_mode (int): number of ego trajectory modes per command.
            use_rescore (bool): if True, penalize colliding ego modes in
                ``select`` via ``rescore``.
        """
        super(HierarchicalPlanningDecoder, self).__init__()
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode
        self.use_rescore = use_rescore

    def decode(
        self,
        det_output,
        motion_output,
        planning_output,
        data,
    ):
        """Build the per-sample planning result list.

        Args:
            det_output (dict): detection head output ("classification",
                "prediction" lists, last element used).
            motion_output (dict): motion head output, same layout.
            planning_output (dict): planning head output with
                "classification", "prediction", "anchor_queue", "period".
            data (dict): batch data; must contain 'gt_ego_fut_cmd'.

        Returns:
            list[dict]: one dict per sample with planning scores,
            trajectories, the selected final plan, and ego temporal queues.
        """
        classification = planning_output['classification'][-1]
        prediction = planning_output['prediction'][-1]
        bs = classification.shape[0]
        # 3 driving commands, each with ego_fut_mode trajectory modes.
        classification = classification.reshape(bs, 3, self.ego_fut_mode)
        # Predictions are per-step offsets; cumsum turns them into
        # positions relative to the current ego pose.
        prediction = prediction.reshape(
            bs, 3, self.ego_fut_mode, self.ego_fut_ts, 2
        ).cumsum(dim=-2)
        classification, final_planning = self.select(
            det_output, motion_output, classification, prediction, data
        )

        anchor_queue = planning_output["anchor_queue"]
        anchor_queue = torch.stack(anchor_queue, dim=2)
        period = planning_output["period"]

        output = []
        for i, (score, pred) in enumerate(zip(classification, prediction)):
            output.append(
                {
                    "planning_score": score.sigmoid().cpu(),
                    "planning": pred.cpu(),
                    "final_planning": final_planning[i].cpu(),
                    "ego_period": period[i].cpu(),
                    "ego_anchor_queue": decode_box(anchor_queue[i]).cpu(),
                }
            )
        return output

    def select(
        self,
        det_output,
        motion_output,
        plan_cls,
        plan_reg,
        data,
    ):
        """Pick the command branch and the best ego mode.

        Indexes the planning classification/regression with the
        ground-truth driving command, optionally rescoring by collision,
        then returns the (possibly updated) full classification and the
        argmax-mode trajectory.

        Returns:
            tuple: (plan_cls_full [bs, 3, ego_fut_mode],
                    final_planning [bs, ego_fut_ts, 2])
        """
        det_classification = det_output["classification"][-1].sigmoid()
        det_anchors = det_output["prediction"][-1]
        det_confidence = det_classification.max(dim=-1).values

        motion_cls = motion_output["classification"][-1].sigmoid()
        motion_reg = motion_output["prediction"][-1]

        # cmd select: keep only the branch matching the GT driving command.
        bs = motion_cls.shape[0]
        bs_indices = torch.arange(bs, device=motion_cls.device)
        cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)
        plan_cls_full = plan_cls.detach().clone()
        plan_cls = plan_cls[bs_indices, cmd]
        plan_reg = plan_reg[bs_indices, cmd]

        # rescore: down-weight ego modes that collide with predicted agents.
        if self.use_rescore:
            plan_cls = self.rescore(
                plan_cls,
                plan_reg,
                motion_cls,
                motion_reg,
                det_anchors,
                det_confidence,
            )
        plan_cls_full[bs_indices, cmd] = plan_cls
        mode_idx = plan_cls.argmax(dim=-1)
        final_planning = plan_reg[bs_indices, mode_idx]
        return plan_cls_full, final_planning

    def rescore(
        self,
        plan_cls,
        plan_reg,
        motion_cls,
        motion_reg,
        det_anchors,
        det_confidence,
        score_thresh=0.5,
        static_dis_thresh=0.5,
        dim_scale=1.1,
        num_motion_mode=1,
        offset=0.5,
    ):
        """Penalize ego modes whose future boxes collide with agents.

        Args:
            plan_cls: [bs, ego_fut_mode] ego mode logits for the chosen cmd.
            plan_reg: [bs, ego_fut_mode, ego_fut_ts, 2] cumulative ego xy.
            motion_cls / motion_reg: agent motion scores and offsets.
            det_anchors / det_confidence: detections and their max scores.
            score_thresh: agents below this confidence are ignored.
            static_dis_thresh: displacement below which yaw is held fixed.
            dim_scale: inflate box dims for a conservative check.
            num_motion_mode: top-k agent modes considered per agent.
            offset: shift of the ego box center along heading (meters).

        Returns:
            Tensor: plan_cls with -999 added to colliding modes (unless
            every mode collides, in which case scores are left unchanged).
        """
        def cat_with_zero(traj):
            # Prepend the origin (t=0) to a trajectory of xy points.
            zeros = traj.new_zeros(traj.shape[:-2] + (1, 2))
            traj_cat = torch.cat([zeros, traj], dim=-2)
            return traj_cat

        def get_yaw(traj, start_yaw=np.pi / 2):
            # Heading from central differences along the trajectory.
            yaw = traj.new_zeros(traj.shape[:-1])
            yaw[..., 1:-1] = torch.atan2(
                traj[..., 2:, 1] - traj[..., :-2, 1],
                traj[..., 2:, 0] - traj[..., :-2, 0],
            )
            yaw[..., -1] = torch.atan2(
                traj[..., -1, 1] - traj[..., -2, 1],
                traj[..., -1, 0] - traj[..., -2, 0],
            )
            yaw[..., 0] = start_yaw
            # for static object, estimated future yaw would be unstable
            start = traj[..., 0, :]
            end = traj[..., -1, :]
            dist = torch.linalg.norm(end - start, dim=-1)
            mask = dist < static_dis_thresh
            start_yaw = yaw[..., 0].unsqueeze(-1)
            yaw = torch.where(
                mask.unsqueeze(-1),
                start_yaw,
                yaw,
            )
            return yaw.unsqueeze(-1)

        ## ego: build 7-DoF boxes along each candidate ego trajectory.
        bs = plan_reg.shape[0]
        plan_reg_cat = cat_with_zero(plan_reg)
        ego_box = det_anchors.new_zeros(
            bs, self.ego_fut_mode, self.ego_fut_ts + 1, 7
        )
        ego_box[..., [X, Y]] = plan_reg_cat
        # Fixed ego dimensions (length, width, height), slightly inflated.
        ego_box[..., [W, L, H]] = (
            ego_box.new_tensor([4.08, 1.73, 1.56]) * dim_scale
        )
        ego_box[..., [YAW]] = get_yaw(plan_reg_cat)

        ## motion: agent boxes along their top-k predicted trajectories.
        motion_reg = motion_reg[..., :self.ego_fut_ts, :].cumsum(-2)
        motion_reg = cat_with_zero(motion_reg) \
            + det_anchors[:, :, None, None, :2]
        _, motion_mode_idx = torch.topk(motion_cls, num_motion_mode, dim=-1)
        motion_mode_idx = motion_mode_idx[..., None, None].repeat(
            1, 1, 1, self.ego_fut_ts + 1, 2
        )
        motion_reg = torch.gather(motion_reg, 2, motion_mode_idx)

        motion_box = motion_reg.new_zeros(motion_reg.shape[:-1] + (7,))
        motion_box[..., [X, Y]] = motion_reg
        motion_box[..., [W, L, H]] = det_anchors[
            ..., None, None, [W, L, H]
        ].exp()
        box_yaw = torch.atan2(
            det_anchors[..., SIN_YAW],
            det_anchors[..., COS_YAW],
        )
        motion_box[..., [YAW]] = get_yaw(motion_reg, box_yaw.unsqueeze(-1))
        # Push low-confidence agents far away so they can never collide.
        filter_mask = det_confidence < score_thresh
        motion_box[filter_mask] = 1e6

        # Drop the prepended t=0 point; only future steps matter.
        ego_box = ego_box[..., 1:, :]
        motion_box = motion_box[..., 1:, :]
        bs, num_ego_mode, ts, _ = ego_box.shape
        bs, num_anchor, num_motion_mode, ts, _ = motion_box.shape
        ego_box = ego_box[:, None, None].repeat(
            1, num_anchor, num_motion_mode, 1, 1, 1
        ).flatten(0, -2)
        motion_box = motion_box.unsqueeze(3).repeat(
            1, 1, 1, num_ego_mode, 1, 1
        ).flatten(0, -2)
        # BUGFIX: the original indexed rows (ego_box[0], ego_box[6]) of the
        # flattened [N, 7] tensor; the intent is to shift each box's x/y
        # *coordinates* forward along its yaw, i.e. index the last dim.
        ego_box[..., 0] += offset * torch.cos(ego_box[..., 6])
        ego_box[..., 1] += offset * torch.sin(ego_box[..., 6])

        col = check_collision(ego_box, motion_box)
        col = col.reshape(
            bs, num_anchor, num_motion_mode, num_ego_mode, ts
        ).permute(0, 3, 1, 2, 4)
        col = col.flatten(2, -1).any(dim=-1)
        all_col = col.all(dim=-1)
        col[all_col] = False  # for case that all modes collide, no need to rescore

        score_offset = col.float() * -999
        plan_cls = plan_cls + score_offset
        return plan_cls
def check_collision(boxes1, boxes2):
    '''
    A rough check for collision detection:
    check if any corner point of boxes1 is inside boxes2 and vice versa.
    boxes1: tensor with shape [N, 7], [x, y, z, w, l, h, yaw]
    boxes2: tensor with shape [N, 7]
    '''
    # Clone before each call: corners_in_box mutates its arguments in place.
    hit_forward = corners_in_box(boxes1.clone(), boxes2.clone())
    hit_backward = corners_in_box(boxes2.clone(), boxes1.clone())
    # A pair collides if either directional containment test fires.
    return torch.logical_or(hit_forward, hit_backward)
def corners_in_box(boxes1, boxes2):
    """Per-pair test: is any (bottom-face) corner of boxes2 inside boxes1?

    Both inputs are [N, 7] boxes paired row-by-row. The pair's frame is
    moved into boxes1's local (yaw-aligned, centered) frame, then boxes2's
    corners are compared against boxes1's axis-aligned extents.

    NOTE: both arguments are modified in place — callers must pass clones.
    NOTE(review): returns plain ``False`` (not a tensor) when either input
    is empty, which would break a later ``torch.logical_or`` — confirm
    callers never hit this with exactly one empty side.
    """
    if boxes1.shape[0] == 0 or boxes2.shape[0] == 0:
        return False
    boxes1_yaw = boxes1[:, 6].clone()
    boxes1_loc = boxes1[:, :3].clone()
    # Rotation by -yaw: brings each pair into boxes1's heading frame.
    cos_yaw = torch.cos(-boxes1_yaw)
    sin_yaw = torch.sin(-boxes1_yaw)
    # rot_mat_T has shape [2, 2, N]: one 2x2 rotation per box pair.
    rot_mat_T = torch.stack(
        [
            torch.stack([cos_yaw, sin_yaw]),
            torch.stack([-sin_yaw, cos_yaw]),
        ]
    )
    # translate and rotate boxes
    boxes1[:, :3] = boxes1[:, :3] - boxes1_loc
    # 'ij,jki->ik': rotate each row's xy by its own per-pair matrix.
    boxes1[:, :2] = torch.einsum('ij,jki->ik', boxes1[:, :2], rot_mat_T)
    boxes1[:, 6] = boxes1[:, 6] - boxes1_yaw
    boxes2[:, :3] = boxes2[:, :3] - boxes1_loc
    boxes2[:, :2] = torch.einsum('ij,jki->ik', boxes2[:, :2], rot_mat_T)
    boxes2[:, 6] = boxes2[:, 6] - boxes1_yaw

    # box3d_to_corners returns a numpy array (hence from_numpy below);
    # indices [0, 3, 7, 4] presumably select one face's 4 corners — TODO
    # confirm against box3d_to_corners' corner ordering.
    corners_box2 = box3d_to_corners(boxes2)[:, [0, 3, 7, 4], :2]
    corners_box2 = torch.from_numpy(corners_box2).to(boxes2.device)
    # NOTE(review): names look swapped relative to the [x,y,z,w,l,h,yaw]
    # layout — H is the dim-3 extent (local x) and W the dim-4 extent
    # (local y); the containment math below is consistent with that.
    H = boxes1[:, [3]]
    W = boxes1[:, [4]]
    collision = torch.logical_and(
        torch.logical_and(
            corners_box2[..., 0] <= H / 2,
            corners_box2[..., 0] >= -H / 2,
        ),
        torch.logical_and(
            corners_box2[..., 1] <= W / 2,
            corners_box2[..., 1] >= -W / 2,
        ),
    )
    # Any one contained corner marks the pair as colliding.
    collision = collision.any(dim=-1)
    return collision
\ No newline at end of file
projects/mmdet3d_plugin/models/motion/instance_queue.py
0 → 100644
View file @
afe88104
import
copy
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
mmcv.utils
import
build_from_cfg
from
mmcv.cnn.bricks.registry
import
PLUGIN_LAYERS
from
projects.mmdet3d_plugin.ops
import
feature_maps_format
from
projects.mmdet3d_plugin.core.box3d
import
*
@PLUGIN_LAYERS.register_module()
class InstanceQueue(nn.Module):
    """Temporal queue of agent and ego instances across frames.

    Maintains short histories (length ``queue_length``) of instance
    features and anchors for detected agents and for the ego vehicle,
    re-associating agents between frames by instance id and projecting
    cached anchors into the current frame.
    """

    def __init__(
        self,
        embed_dims,
        queue_length=0,
        tracking_threshold=0,
        feature_map_scale=None,
    ):
        """
        Args:
            embed_dims (int): feature dimension of instances.
            queue_length (int): max number of cached timesteps.
            tracking_threshold (float): min previous confidence for an
                agent to be matched across frames (0 disables the gate).
            feature_map_scale (tuple): spatial size of the feature map fed
                to the ego encoder; halved once by the stride-2 conv, the
                remainder is average-pooled away. Required (no default
                behavior for None).
        """
        super(InstanceQueue, self).__init__()
        self.embed_dims = embed_dims
        self.queue_length = queue_length
        self.tracking_threshold = tracking_threshold

        # After the stride-2 conv the map is half-size; pool the rest.
        kernel_size = tuple([int(x / 2) for x in feature_map_scale])
        self.ego_feature_encoder = nn.Sequential(
            nn.Conv2d(embed_dims, embed_dims, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embed_dims),
            nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dims),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size),
        )
        # Fixed ego anchor: xyz center, log-dims of the ego car
        # (4.08 x 1.73 x 1.56 m), unit yaw (cos=1, sin=0), zero velocity.
        self.ego_anchor = nn.Parameter(
            torch.tensor(
                [[0, 0.5, -1.84 + 1.56 / 2,
                  np.log(4.08), np.log(1.73), np.log(1.56),
                  1, 0, 0, 0, 0],],
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        self.reset()

    def reset(self):
        """Clear all cached temporal state (new sequence)."""
        self.metas = None
        self.prev_instance_id = None
        self.prev_confidence = None
        self.period = None
        self.instance_feature_queue = []
        self.anchor_queue = []
        self.prev_ego_status = None
        self.ego_period = None
        self.ego_feature_queue = []
        self.ego_anchor_queue = []

    def get(
        self,
        det_output,
        feature_maps,
        metas,
        batch_size,
        mask,
        anchor_handler,
    ):
        """Update queues for the current frame and return temporal inputs.

        Returns:
            tuple: (ego_feature, ego_anchor,
                    temp_instance_feature [bs, num_agent, T, C],
                    temp_anchor [bs, num_agent, T, ...],
                    temp_mask — True where a queue slot is older than an
                    instance's tracked period and must be ignored)
        """
        if (
            self.period is not None
            and batch_size == self.period.shape[0]
        ):
            # Same sequence: project cached anchors from the previous
            # global frame into the current one.
            if anchor_handler is not None:
                T_temp2cur = feature_maps[0].new_tensor(
                    np.stack(
                        [
                            x["T_global_inv"]
                            @ self.metas["img_metas"][i]["T_global"]
                            for i, x in enumerate(metas["img_metas"])
                        ]
                    )
                )
                for i in range(len(self.anchor_queue)):
                    temp_anchor = self.anchor_queue[i]
                    temp_anchor = anchor_handler.anchor_projection(
                        temp_anchor,
                        [T_temp2cur],
                    )[0]
                    self.anchor_queue[i] = temp_anchor
                for i in range(len(self.ego_anchor_queue)):
                    temp_anchor = self.ego_anchor_queue[i]
                    temp_anchor = anchor_handler.anchor_projection(
                        temp_anchor,
                        [T_temp2cur],
                    )[0]
                    self.ego_anchor_queue[i] = temp_anchor
        else:
            self.reset()

        self.prepare_motion(det_output, mask)
        ego_feature, ego_anchor = self.prepare_planning(
            feature_maps, mask, batch_size
        )

        # temporal: stack per-step caches into [bs, num, T, ...] tensors
        # and append the ego instance as one extra agent.
        temp_instance_feature = torch.stack(self.instance_feature_queue, dim=2)
        temp_anchor = torch.stack(self.anchor_queue, dim=2)
        temp_ego_feature = torch.stack(self.ego_feature_queue, dim=2)
        temp_ego_anchor = torch.stack(self.ego_anchor_queue, dim=2)

        period = torch.cat([self.period, self.ego_period], dim=1)
        temp_instance_feature = torch.cat(
            [temp_instance_feature, temp_ego_feature], dim=1
        )
        temp_anchor = torch.cat([temp_anchor, temp_ego_anchor], dim=1)
        num_agent = temp_anchor.shape[1]

        # Slot ages count down from queue length; slots older than an
        # instance's period are masked out.
        temp_mask = torch.arange(
            len(self.anchor_queue), 0, -1, device=temp_anchor.device
        )
        temp_mask = temp_mask[None, None].repeat((batch_size, num_agent, 1))
        temp_mask = torch.gt(temp_mask, period[..., None])

        return ego_feature, ego_anchor, temp_instance_feature, temp_anchor, temp_mask

    def prepare_motion(
        self,
        det_output,
        mask,
    ):
        """Re-associate cached agent queues by instance id and push the
        current frame's detections."""
        instance_feature = det_output["instance_feature"]
        det_anchors = det_output["prediction"][-1]

        # BUGFIX: identity comparison to None must use `is`, not `==`.
        if self.period is None:
            self.period = instance_feature.new_zeros(
                instance_feature.shape[:2]
            ).long()
        else:
            instance_id = det_output['instance_id']
            prev_instance_id = self.prev_instance_id
            # match[b, cur, prev] — one-hot association by instance id.
            match = instance_id[..., None] == prev_instance_id[:, None]
            if self.tracking_threshold > 0:
                # Only carry over instances that were confident last frame.
                temp_mask = self.prev_confidence > self.tracking_threshold
                match = match * temp_mask.unsqueeze(1)

            # Gather each cached step into the current instance order;
            # unmatched instances collapse to zeros.
            for i in range(len(self.instance_feature_queue)):
                temp_feature = self.instance_feature_queue[i]
                temp_feature = (
                    match[..., None] * temp_feature[:, None]
                ).sum(dim=2)
                self.instance_feature_queue[i] = temp_feature

                temp_anchor = self.anchor_queue[i]
                temp_anchor = (
                    match[..., None] * temp_anchor[:, None]
                ).sum(dim=2)
                self.anchor_queue[i] = temp_anchor
            self.period = (match * self.period[:, None]).sum(dim=2)

        self.instance_feature_queue.append(instance_feature.detach())
        self.anchor_queue.append(det_anchors.detach())
        self.period += 1

        if len(self.instance_feature_queue) > self.queue_length:
            self.instance_feature_queue.pop(0)
            self.anchor_queue.pop(0)
        self.period = torch.clip(self.period, 0, self.queue_length)

    def prepare_planning(
        self,
        feature_maps,
        mask,
        batch_size,
    ):
        """Encode the ego instance for this frame and update ego queues.

        Returns:
            tuple: (ego_feature [bs, 1, C], ego_anchor [bs, 1, 11])
        """
        ## ego instance init
        feature_maps_inv = feature_maps_format(feature_maps, inverse=True)
        # Coarsest level, front camera — presumably; TODO confirm layout.
        feature_map = feature_maps_inv[0][-1][:, 0]
        ego_feature = self.ego_feature_encoder(feature_map)
        ego_feature = ego_feature.unsqueeze(1).squeeze(-1).squeeze(-1)

        ego_anchor = torch.tile(
            self.ego_anchor[None], (batch_size, 1, 1)
        )
        if self.prev_ego_status is not None:
            # Zero out cached status for samples starting a new sequence.
            prev_ego_status = torch.where(
                mask[:, None, None],
                self.prev_ego_status,
                self.prev_ego_status.new_tensor(0),
            )
            # Carry over longitudinal velocity from the previous status.
            ego_anchor[..., VY] = prev_ego_status[..., 6]

        # BUGFIX: identity comparison to None must use `is`, not `==`.
        if self.ego_period is None:
            self.ego_period = ego_feature.new_zeros((batch_size, 1)).long()
        else:
            self.ego_period = torch.where(
                mask[:, None],
                self.ego_period,
                self.ego_period.new_tensor(0),
            )

        self.ego_feature_queue.append(ego_feature.detach())
        self.ego_anchor_queue.append(ego_anchor.detach())
        self.ego_period += 1

        if len(self.ego_feature_queue) > self.queue_length:
            self.ego_feature_queue.pop(0)
            self.ego_anchor_queue.pop(0)
        self.ego_period = torch.clip(self.ego_period, 0, self.queue_length)

        return ego_feature, ego_anchor

    def cache_motion(self, instance_feature, det_output, metas):
        """Cache per-frame agent state used for next-frame association."""
        det_classification = det_output["classification"][-1].sigmoid()
        det_confidence = det_classification.max(dim=-1).values
        instance_id = det_output['instance_id']
        self.metas = metas
        self.prev_confidence = det_confidence.detach()
        self.prev_instance_id = instance_id

    def cache_planning(self, ego_feature, ego_status):
        """Cache refined ego feature/status, replacing this frame's entry."""
        self.prev_ego_status = ego_status.detach()
        self.ego_feature_queue[-1] = ego_feature.detach()
projects/mmdet3d_plugin/models/motion/motion_blocks.py
0 → 100644
View file @
afe88104
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
mmcv.cnn
import
Linear
,
Scale
,
bias_init_with_prob
from
mmcv.runner.base_module
import
Sequential
,
BaseModule
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.registry
import
(
PLUGIN_LAYERS
,
)
from
projects.mmdet3d_plugin.core.box3d
import
*
from
..blocks
import
linear_relu_ln
@PLUGIN_LAYERS.register_module()
class MotionPlanningRefinementModule(BaseModule):
    """Refinement head emitting motion and planning scores/trajectories.

    Produces, from the shared queries: per-mode agent motion
    classification and regression, per-mode ego planning classification
    and regression (3 commands x ego_fut_mode modes), and a 10-dim ego
    status prediction.
    """

    def __init__(
        self,
        embed_dims=256,
        fut_ts=12,
        fut_mode=6,
        ego_fut_ts=6,
        ego_fut_mode=3,
    ):
        super(MotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.fut_ts = fut_ts
        self.fut_mode = fut_mode
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

        # Branch layouts are identical to the hand-written originals, so
        # state-dict keys and module structure are unchanged.
        self.motion_cls_branch = self._build_cls_branch(embed_dims)
        self.motion_reg_branch = self._build_reg_branch(embed_dims, fut_ts * 2)
        self.plan_cls_branch = self._build_cls_branch(embed_dims)
        self.plan_reg_branch = self._build_reg_branch(embed_dims, ego_fut_ts * 2)
        self.plan_status_branch = self._build_reg_branch(embed_dims, 10)

    @staticmethod
    def _build_cls_branch(embed_dims):
        # Two LN+ReLU linear blocks followed by a single-logit head.
        return nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2),
            Linear(embed_dims, 1),
        )

    @staticmethod
    def _build_reg_branch(embed_dims, out_dims):
        # Plain 3-layer MLP regressor.
        return nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, out_dims),
        )

    def init_weight(self):
        # Bias the classification heads toward low initial probability.
        prior_bias = bias_init_with_prob(0.01)
        nn.init.constant_(self.motion_cls_branch[-1].bias, prior_bias)
        nn.init.constant_(self.plan_cls_branch[-1].bias, prior_bias)

    def forward(
        self,
        motion_query,
        plan_query,
        ego_feature,
        ego_anchor_embed,
    ):
        """Run all five branches.

        Returns:
            tuple: (motion_cls, motion_reg, plan_cls, plan_reg,
                    planning_status)
        """
        batch, num_agent = motion_query.shape[:2]

        motion_cls = self.motion_cls_branch(motion_query).squeeze(-1)
        motion_reg = self.motion_reg_branch(motion_query)
        motion_reg = motion_reg.reshape(
            batch, num_agent, self.fut_mode, self.fut_ts, 2
        )

        plan_cls = self.plan_cls_branch(plan_query).squeeze(-1)
        plan_reg = self.plan_reg_branch(plan_query)
        # 3 driving commands, each with ego_fut_mode trajectory modes.
        plan_reg = plan_reg.reshape(
            batch, 1, 3 * self.ego_fut_mode, self.ego_fut_ts, 2
        )

        planning_status = self.plan_status_branch(
            ego_feature + ego_anchor_embed
        )
        return motion_cls, motion_reg, plan_cls, plan_reg, planning_status
\ No newline at end of file
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment