Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
96dfcbd0
Commit
96dfcbd0
authored
Jun 24, 2020
by
Gus-Guo
Browse files
add AxisAlignedTargetAssigner
parent
1059a42b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
187 additions
and
2 deletions
+187
-2
pcdet/models/dense_heads/anchor_head_template.py
pcdet/models/dense_heads/anchor_head_template.py
+3
-2
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
...nse_heads/target_assigner/axis_aligned_target_assigner.py
+184
-0
No files found.
pcdet/models/dense_heads/anchor_head_template.py
View file @
96dfcbd0
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.target_assigner.anchor_generator
import
AnchorGenerator
from
.target_assigner.anchor_generator
import
AnchorGenerator
from
.target_assigner.atss_target_assigner
import
ATSSTargetAssigner
from
.target_assigner.atss_target_assigner
import
ATSSTargetAssigner
from
.target_assigner.axis_aligned_target_assigner
import
AxisAlignedTargetAssigner
from
...utils
import
box_coder_utils
,
loss_utils
,
common_utils
from
...utils
import
box_coder_utils
,
loss_utils
,
common_utils
...
@@ -45,8 +46,8 @@ class AnchorHeadTemplate(nn.Module):
...
@@ -45,8 +46,8 @@ class AnchorHeadTemplate(nn.Module):
box_coder
=
self
.
box_coder
,
box_coder
=
self
.
box_coder
,
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
)
)
elif
anchor_target_cfg
.
NAME
==
'
Second
'
:
elif
anchor_target_cfg
.
NAME
==
'
AxisAlignedTargetAssigner
'
:
target_assigner
=
Secon
dTargetAssigner
(
target_assigner
=
AxisAligne
dTargetAssigner
(
anchor_target_cfg
=
anchor_target_cfg
,
anchor_target_cfg
=
anchor_target_cfg
,
box_coder
=
self
.
box_coder
,
box_coder
=
self
.
box_coder
,
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
...
...
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
0 → 100644
View file @
96dfcbd0
import
torch
from
....utils
import
box_utils
from
....ops.iou3d_nms
import
iou3d_nms_utils
class
AxisAlignedTargetAssigner
(
object
):
def
__init__
(
self
,
anchor_target_cfg
,
box_coder
,
match_height
=
False
):
super
().
__init__
()
self
.
box_coder
=
box_coder
self
.
match_height
=
match_height
self
.
pos_fraction
=
anchor_target_cfg
.
POS_FRACTION
if
anchor_target_cfg
.
POS_FRACTION
>=
0
else
None
self
.
sample_size
=
anchor_target_cfg
.
SAMPLE_SIZE
self
.
matched_thresholds
=
anchor_target_cfg
.
MATCHED_THRESHOLDS
self
.
unmatched_thresholds
=
anchor_target_cfg
.
UNMATCHED_THRESHOLDS
self
.
norm_by_num_examples
=
anchor_target_cfg
.
NORM_BY_NUM_EXAMPLES
def
assign_targets
(
self
,
all_anchors
,
gt_boxes_with_classes
,
use_multihead
=
False
):
"""
Args:
all_anchors: [(N, 7), ...]
gt_boxes: (B, M, 8)
Returns:
"""
bbox_targets
=
[]
bbox_src_targets
=
[]
cls_labels
=
[]
reg_weights
=
[]
batch_size
=
gt_boxes_with_classes
.
shape
[
0
]
gt_classes
=
gt_boxes_with_classes
[:,
:,
7
]
gt_boxes
=
gt_boxes_with_classes
[:,
:,
:
7
]
for
k
in
range
(
batch_size
):
cur_gt
=
gt_boxes
[
k
]
cnt
=
cur_gt
.
__len__
()
-
1
while
cnt
>
0
and
cur_gt
[
cnt
].
sum
()
==
0
:
cnt
-=
1
cur_gt
=
cur_gt
[:
cnt
+
1
]
cur_gt_classes
=
gt_classes
[
k
][:
cnt
+
1
].
int
()
target_list
=
[]
for
class_index
,
anchors
in
enumerate
(
all_anchors
):
mask
=
torch
.
tensor
([
c
==
class_index
+
1
for
c
in
cur_gt_classes
],
dtype
=
torch
.
bool
)
if
use_multihead
:
anchors
=
anchors
.
permute
(
3
,
4
,
0
,
1
,
2
,
5
).
contiguous
().
view
(
-
1
,
anchors
.
shape
[
-
1
])
else
:
feature_map_size
=
anchors
.
shape
[:
3
]
anchors
=
anchors
.
view
(
-
1
,
anchors
.
shape
[
-
1
])
single_target
=
self
.
assign_targets_single
(
anchors
,
cur_gt
[
mask
],
gt_classes
=
cur_gt_classes
[
mask
],
matched_threshold
=
self
.
matched_thresholds
[
class_index
],
unmatched_threshold
=
self
.
unmatched_thresholds
[
class_index
]
)
target_list
.
append
(
single_target
)
if
use_multihead
:
target_dict
=
{
'box_cls_labels'
:
[
t
[
'box_cls_labels'
].
view
(
-
1
)
for
t
in
target_list
],
'box_reg_targets'
:
[
t
[
'box_reg_targets'
].
view
(
-
1
,
self
.
box_coder
.
code_size
)
for
t
in
target_list
],
'reg_weights'
:
[
t
[
'reg_weights'
].
view
(
-
1
)
for
t
in
target_list
]
}
target_dict
[
'box_reg_targets'
]
=
torch
.
cat
(
target_dict
[
'box_reg_targets'
],
dim
=
0
)
target_dict
[
'box_cls_labels'
]
=
torch
.
cat
(
target_dict
[
'box_cls_labels'
],
dim
=
0
).
view
(
-
1
)
target_dict
[
'reg_weights'
]
=
torch
.
cat
(
target_dict
[
'reg_weights'
],
dim
=
0
).
view
(
-
1
)
else
:
target_dict
=
{
'box_cls_labels'
:
[
t
[
'box_cls_labels'
].
view
(
*
feature_map_size
,
-
1
)
for
t
in
target_list
],
'box_reg_targets'
:
[
t
[
'box_reg_targets'
].
view
(
*
feature_map_size
,
-
1
,
self
.
box_coder
.
code_size
)
for
t
in
target_list
],
'reg_weights'
:
[
t
[
'reg_weights'
].
view
(
*
feature_map_size
,
-
1
)
for
t
in
target_list
]
}
target_dict
[
'box_reg_targets'
]
=
torch
.
cat
(
target_dict
[
'box_reg_targets'
],
dim
=-
2
).
view
(
-
1
,
self
.
box_coder
.
code_size
)
target_dict
[
'box_cls_labels'
]
=
torch
.
cat
(
target_dict
[
'box_cls_labels'
],
dim
=-
1
).
view
(
-
1
)
target_dict
[
'reg_weights'
]
=
torch
.
cat
(
target_dict
[
'reg_weights'
],
dim
=-
1
).
view
(
-
1
)
bbox_targets
.
append
(
target_dict
[
'box_reg_targets'
])
cls_labels
.
append
(
target_dict
[
'box_cls_labels'
])
reg_weights
.
append
(
target_dict
[
'reg_weights'
])
bbox_targets
=
torch
.
stack
(
bbox_targets
,
dim
=
0
)
cls_labels
=
torch
.
stack
(
cls_labels
,
dim
=
0
)
reg_weights
=
torch
.
stack
(
reg_weights
,
dim
=
0
)
all_targets_dict
=
{
'box_cls_labels'
:
cls_labels
,
'box_reg_targets'
:
bbox_targets
,
'reg_weights'
:
reg_weights
}
return
all_targets_dict
def
assign_targets_single
(
self
,
anchors
,
gt_boxes
,
gt_classes
,
matched_threshold
=
0.6
,
unmatched_threshold
=
0.45
):
num_anchors
=
anchors
.
shape
[
0
]
num_gt
=
gt_boxes
.
shape
[
0
]
box_ndim
=
anchors
.
shape
[
1
]
labels
=
torch
.
ones
((
num_anchors
,),
dtype
=
torch
.
int32
,
device
=
anchors
.
device
)
*
-
1
gt_ids
=
torch
.
ones
((
num_anchors
,),
dtype
=
torch
.
int32
,
device
=
anchors
.
device
)
*
-
1
if
len
(
gt_boxes
)
>
0
and
anchors
.
shape
[
0
]
>
0
:
anchor_by_gt_overlap
=
iou3d_nms_utils
.
boxes_iou3d_gpu
(
anchors
,
gt_boxes
)
if
self
.
match_height
else
box_utils
.
boxes3d_nearest_bev_iou
(
anchors
,
gt_boxes
)
anchor_to_gt_argmax
=
anchor_by_gt_overlap
.
argmax
(
dim
=
1
)
anchor_to_gt_max
=
anchor_by_gt_overlap
[
torch
.
arange
(
num_anchors
),
anchor_to_gt_argmax
]
gt_to_anchor_argmax
=
anchor_by_gt_overlap
.
argmax
(
dim
=
0
)
gt_to_anchor_max
=
anchor_by_gt_overlap
[
gt_to_anchor_argmax
,
torch
.
arange
(
num_gt
)]
empty_gt_mask
=
gt_to_anchor_max
==
0
gt_to_anchor_max
[
empty_gt_mask
]
=
-
1
anchors_with_max_overlap
=
torch
.
nonzero
(
anchor_by_gt_overlap
==
gt_to_anchor_max
)[:,
0
]
gt_inds_force
=
anchor_to_gt_argmax
[
anchors_with_max_overlap
]
labels
[
anchors_with_max_overlap
]
=
gt_classes
[
gt_inds_force
]
gt_ids
[
anchors_with_max_overlap
]
=
gt_inds_force
.
int
()
pos_inds
=
anchor_to_gt_max
>=
matched_threshold
gt_inds_over_thresh
=
anchor_to_gt_argmax
[
pos_inds
]
labels
[
pos_inds
]
=
gt_classes
[
gt_inds_over_thresh
]
gt_ids
[
pos_inds
]
=
gt_inds_over_thresh
.
int
()
bg_inds
=
torch
.
nonzero
(
anchor_to_gt_max
<
unmatched_threshold
)[:,
0
]
else
:
bg_inds
=
torch
.
arange
(
num_anchors
)
fg_inds
=
torch
.
nonzero
(
labels
>
0
)[:,
0
]
if
self
.
pos_fraction
is
not
None
:
num_fg
=
int
(
self
.
pos_fraction
*
self
.
sample_size
)
if
len
(
fg_inds
)
>
num_fg
:
num_disabled
=
len
(
fg_inds
)
-
num_fg
disable_inds
=
torch
.
randperm
(
len
(
fg_inds
))[:
num_disabled
]
labels
[
disable_inds
]
=
-
1
fg_inds
=
torch
.
nonzero
(
labels
>
0
)[:,
0
]
num_bg
=
self
.
sample_size
-
(
labels
>
0
).
sum
()
if
len
(
bg_inds
)
>
num_bg
:
enable_inds
=
bg_inds
[
torch
.
randint
(
0
,
len
(
bg_inds
),
size
=
(
num_bg
,))]
labels
[
enable_inds
]
=
0
bg_inds
=
torch
.
nonzero
(
labels
==
0
)[:,
0
]
else
:
if
len
(
gt_boxes
)
==
0
or
anchors
.
shape
[
0
]
==
0
:
labels
[:]
=
0
else
:
labels
[
bg_inds
]
=
0
labels
[
anchors_with_max_overlap
]
=
gt_classes
[
gt_inds_force
]
bbox_targets
=
anchors
.
new_zeros
((
num_anchors
,
self
.
box_coder
.
code_size
))
if
len
(
gt_boxes
)
>
0
and
anchors
.
shape
[
0
]
>
0
:
fg_gt_boxes
=
gt_boxes
[
anchor_to_gt_argmax
[
fg_inds
],
:]
fg_anchors
=
anchors
[
fg_inds
,
:]
bbox_targets
[
fg_inds
,
:]
=
self
.
box_coder
.
encode_torch
(
fg_gt_boxes
,
fg_anchors
)
reg_weights
=
anchors
.
new_zeros
((
num_anchors
,))
if
self
.
norm_by_num_examples
:
num_examples
=
(
labels
>=
0
).
sum
()
num_examples
=
num_examples
if
num_examples
>
1.0
else
1.0
reg_weights
[
labels
>
0
]
=
1.0
/
num_examples
else
:
reg_weights
[
labels
>
0
]
=
1.0
ret_dict
=
{
'box_cls_labels'
:
labels
,
'box_reg_targets'
:
bbox_targets
,
'reg_weights'
:
reg_weights
,
}
return
ret_dict
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment