OpenDAS / mmdetection3d / Commits

Commit e3cd3c1d, authored May 18, 2020 by zhangwenwei
Refactor dense heads
Parent: 8c5dd998

Showing 11 changed files with 175 additions and 189 deletions (+175, -189)
Changed files:

  mmdet3d/models/dense_heads/__init__.py                     +4    -0
  mmdet3d/models/dense_heads/anchor3d_head.py                 +85   -89
  mmdet3d/models/dense_heads/parta2_rpn_head.py               +34   -63
  mmdet3d/models/dense_heads/train_mixins.py                  +0    -0
  mmdet3d/models/detectors/mvx_single_stage.py                +8    -3
  mmdet3d/models/detectors/mvx_two_stage.py                   +8    -3
  mmdet3d/models/detectors/voxelnet.py                        +7    -3
  mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py     +5    -12
  mmdet3d/models/roi_heads/part_aggregation_roi_head.py       +17   -9
  requirements/runtime.txt                                    +1    -1
  tests/test_heads.py                                         +6    -6

mmdet3d/models/dense_heads/__init__.py  (new file, mode 100644)

+ from .anchor3d_head import Anchor3DHead
+ from .parta2_rpn_head import PartA2RPNHead
+
+ __all__ = ['Anchor3DHead', 'PartA2RPNHead']
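
Both heads are registered with @HEADS.register_module(), so downstream configs construct them through the HEADS registry instead of importing the classes directly. As a rough illustration of what an Anchor3DHead entry could look like after this rename, based only on the constructor signature visible in the diff below (the class count, channel widths, bbox_coder type and loss settings are placeholders, and train_cfg/test_cfg are normally injected by the detector wrapper):

    # hypothetical config fragment (illustrative values only, not from this commit)
    bbox_head = dict(
        type='Anchor3DHead',               # resolved through the HEADS registry
        num_classes=3,                     # replaces the old class_name list
        in_channels=384,
        feat_channels=384,
        use_direction_classifier=True,
        anchor_generator=dict(
            type='Anchor3DRangeGenerator',
            range=[0, -39.68, -1.78, 69.12, 39.68, -1.78]),
        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),  # assumed coder name
        loss_cls=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
        loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2))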

mmdet3d/models/anchor_heads/second_head.py → mmdet3d/models/dense_heads/anchor3d_head.py  (renamed)

@@ -3,30 +3,26 @@ import torch
  import torch.nn as nn
  from mmcv.cnn import bias_init_with_prob, normal_init
- from mmdet3d.core import (PseudoSampler, box_torch_ops,
+ from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, box_torch_ops,
                            boxes3d_to_bev_torch_lidar, build_anchor_generator,
                            build_assigner, build_bbox_coder, build_sampler,
                            multi_apply)
- from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
  from mmdet.models import HEADS
  from ..builder import build_loss
  from .train_mixins import AnchorTrainMixin


  @HEADS.register_module()
- class SECONDHead(nn.Module, AnchorTrainMixin):
-     """Anchor-based head for VoxelNet detectors.
+ class Anchor3DHead(nn.Module, AnchorTrainMixin):
+     """Anchor head for SECOND/PointPillars/MVXNet/PartA2.

      Args:
-         class_name (list[str]): name of classes (TODO: to be removed)
+         num_classes (int): Number of classes.
          in_channels (int): Number of channels in the input feature map.
          train_cfg (dict): train configs
          test_cfg (dict): test configs
          feat_channels (int): Number of channels of the feature map.
          use_direction_classifier (bool): Whether to add a direction classifier.
-         encode_bg_as_zeros (bool): Whether to use sigmoid of softmax
-             (TODO: to be removed)
-         box_code_size (int): The size of box code.
          anchor_generator(dict): Config dict of anchor generator.
          assigner_per_size (bool): Whether to do assignment for each separate
              anchor size.

@@ -41,17 +37,15 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
          loss_cls (dict): Config of classification loss.
          loss_bbox (dict): Config of localization loss.
          loss_dir (dict): Config of direction classifier loss.
-     """  # noqa: W605
+     """

      def __init__(self,
-                  class_name,
+                  num_classes,
                   in_channels,
                   train_cfg,
                   test_cfg,
                   feat_channels=256,
                   use_direction_classifier=True,
-                  encode_bg_as_zeros=False,
-                  box_code_size=7,
                   anchor_generator=dict(
                       type='Anchor3DRangeGenerator',
                       range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],

@@ -75,41 +69,28 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
                   loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
          super().__init__()
          self.in_channels = in_channels
-         self.num_classes = len(class_name)
+         self.num_classes = num_classes
          self.feat_channels = feat_channels
          self.diff_rad_by_sin = diff_rad_by_sin
          self.use_direction_classifier = use_direction_classifier
-         # self.encode_background_as_zeros = encode_bg_as_zeros
-         self.box_code_size = box_code_size
          self.train_cfg = train_cfg
          self.test_cfg = test_cfg
-         self.bbox_coder = build_bbox_coder(bbox_coder)
          self.assigner_per_size = assigner_per_size
          self.assign_per_class = assign_per_class
          self.dir_offset = dir_offset
          self.dir_limit_offset = dir_limit_offset
-
-         # build target assigner & sampler
-         if train_cfg is not None:
-             self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
-             if self.sampling:
-                 self.bbox_sampler = build_sampler(train_cfg.sampler)
-             else:
-                 self.bbox_sampler = PseudoSampler()
-             if isinstance(train_cfg.assigner, dict):
-                 self.bbox_assigner = build_assigner(train_cfg.assigner)
-             elif isinstance(train_cfg.assigner, list):
-                 self.bbox_assigner = [
-                     build_assigner(res) for res in train_cfg.assigner
-                 ]

          # build anchor generator
          self.anchor_generator = build_anchor_generator(anchor_generator)
          # In 3D detection, the anchor stride is connected with anchor size
          self.num_anchors = self.anchor_generator.num_base_anchors
-         self._init_layers()
+         # build box coder
+         self.bbox_coder = build_bbox_coder(bbox_coder)
+         self.box_code_size = self.bbox_coder.code_size

+         # build loss function
          self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+         self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
          if not self.use_sigmoid_cls:
              self.num_classes += 1
          self.loss_cls = build_loss(loss_cls)

@@ -117,6 +98,24 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
          self.loss_dir = build_loss(loss_dir)
          self.fp16_enabled = False

+         self._init_layers()
+         self._init_assigner_sampler()
+
+     def _init_assigner_sampler(self):
+         if self.train_cfg is None:
+             return
+
+         if self.sampling:
+             self.bbox_sampler = build_sampler(self.train_cfg.sampler)
+         else:
+             self.bbox_sampler = PseudoSampler()
+         if isinstance(self.train_cfg.assigner, dict):
+             self.bbox_assigner = build_assigner(self.train_cfg.assigner)
+         elif isinstance(self.train_cfg.assigner, list):
+             self.bbox_assigner = [
+                 build_assigner(res) for res in self.train_cfg.assigner
+             ]
+
      def _init_layers(self):
          self.cls_out_channels = self.num_anchors * self.num_classes
          self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)

@@ -144,9 +143,12 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
      def get_anchors(self, featmap_sizes, input_metas, device='cuda'):
          """Get anchors according to feature map sizes.

          Args:
              featmap_sizes (list[tuple]): Multi-level feature map sizes.
              input_metas (list[dict]): contain pcd and img's meta info.
+             device (str): device of current module

          Returns:
              tuple: anchors of each image, valid flags of each image
          """

@@ -204,12 +206,25 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
      @staticmethod
      def add_sin_difference(boxes1, boxes2):
-         rad_pred_encoding = torch.sin(boxes1[..., -1:]) * torch.cos(
-             boxes2[..., -1:])
-         rad_tg_encoding = torch.cos(boxes1[..., -1:]) * torch.sin(boxes2[..., -1:])
-         boxes1 = torch.cat([boxes1[..., :-1], rad_pred_encoding], dim=-1)
-         boxes2 = torch.cat([boxes2[..., :-1], rad_tg_encoding], dim=-1)
+         """Convert the rotation difference to difference in sine function
+
+         Args:
+             boxes1 (Tensor): shape (NxC), where C>=7 and the 7th dimension is
+                 rotation dimension
+             boxes2 (Tensor): shape (NxC), where C>=7 and the 7th dimension is
+                 rotation dimension
+
+         Returns:
+             tuple: (boxes1, boxes2) whose 7th dimensions are changed
+         """
+         rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos(
+             boxes2[..., 6:7])
+         rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[..., 6:7])
+         boxes1 = torch.cat(
+             [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1)
+         boxes2 = torch.cat(
+             [boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]], dim=-1)
          return boxes1, boxes2

      def loss(self,
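
Two things change in add_sin_difference: the method gains a docstring, and the rotation channel is now addressed as boxes[..., 6:7] instead of boxes[..., -1:], so any dimensions after the 7th are preserved by the concatenation rather than silently treated as the angle. The underlying encoding is unchanged and follows sin(a - b) = sin(a)cos(b) - cos(a)sin(b), which lets the regression loss see the angle difference without a 2*pi wrap-around. A small self-contained sketch (7-dim boxes [x, y, z, w, l, h, ry]; the values are illustrative):

    import torch

    pred = torch.tensor([[0., 0., 0., 1., 1., 1., 0.1]])
    target = torch.tensor([[0., 0., 0., 1., 1., 1., -0.2]])
    # encode only the rotation channel; the other channels stay untouched
    rad_pred_encoding = torch.sin(pred[..., 6:7]) * torch.cos(target[..., 6:7])
    rad_tg_encoding = torch.cos(pred[..., 6:7]) * torch.sin(target[..., 6:7])
    # their difference equals sin(pred_ry - target_ry) = sin(0.3)
    print((rad_pred_encoding - rad_tg_encoding).abs())  # ~0.2955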

@@ -267,6 +282,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
                     bbox_preds,
                     dir_cls_preds,
                     input_metas,
+                    cfg=None,
                     rescale=False):
          assert len(cls_scores) == len(bbox_preds)
          assert len(cls_scores) == len(dir_cls_preds)

@@ -294,7 +310,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
              input_meta = input_metas[img_id]
              proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
                                                 dir_cls_pred_list, mlvl_anchors,
-                                                input_meta, rescale)
+                                                input_meta, cfg, rescale)
              result_list.append(proposals)
          return result_list

@@ -304,17 +320,19 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
                            dir_cls_preds,
                            mlvl_anchors,
                            input_meta,
+                           cfg=None,
                            rescale=False):
+         cfg = self.test_cfg if cfg is None else cfg
          assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
          mlvl_bboxes = []
          mlvl_scores = []
          mlvl_dir_scores = []
-         mlvl_bboxes_for_nms = []
          for cls_score, bbox_pred, dir_cls_pred, anchors in zip(
                  cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors):
              assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-             if self.use_direction_classifier:
-                 assert cls_score.size()[-2:] == dir_cls_pred.size()[-2:]
+             assert cls_score.size()[-2:] == dir_cls_pred.size()[-2:]
+             dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
+             dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
              cls_score = cls_score.permute(1, 2,
                                            0).reshape(-1, self.num_classes)

@@ -324,66 +342,44 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
                  scores = cls_score.softmax(-1)
              bbox_pred = bbox_pred.permute(1, 2,
                                            0).reshape(-1, self.box_code_size)
-             dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2)
-             dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]

-             score_thr = self.test_cfg.get('score_thr', 0)
-             if score_thr > 0:
+             nms_pre = cfg.get('nms_pre', -1)
+             if nms_pre > 0 and scores.shape[0] > nms_pre:
                  if self.use_sigmoid_cls:
                      max_scores, _ = scores.max(dim=1)
                  else:
-                     max_scores, _ = scores[:, 1:].max(dim=1)
-                 thr_inds = (max_scores >= score_thr)
-                 anchors = anchors[thr_inds]
-                 bbox_pred = bbox_pred[thr_inds]
-                 scores = scores[thr_inds]
-                 dir_cls_scores = dir_cls_score[thr_inds]
+                     max_scores, _ = scores[:, :-1].max(dim=1)
+                 _, topk_inds = max_scores.topk(nms_pre)
+                 anchors = anchors[topk_inds, :]
+                 bbox_pred = bbox_pred[topk_inds, :]
+                 scores = scores[topk_inds, :]
+                 dir_cls_score = dir_cls_score[topk_inds]

              bboxes = self.bbox_coder.decode(anchors, bbox_pred)
-             bboxes_for_nms = boxes3d_to_bev_torch_lidar(bboxes)
-             mlvl_bboxes_for_nms.append(bboxes_for_nms)
              mlvl_bboxes.append(bboxes)
              mlvl_scores.append(scores)
-             mlvl_dir_scores.append(dir_cls_scores)
+             mlvl_dir_scores.append(dir_cls_score)

          mlvl_bboxes = torch.cat(mlvl_bboxes)
-         mlvl_bboxes_for_nms = torch.cat(mlvl_bboxes_for_nms)
+         mlvl_bboxes_for_nms = boxes3d_to_bev_torch_lidar(mlvl_bboxes)
          mlvl_scores = torch.cat(mlvl_scores)
          mlvl_dir_scores = torch.cat(mlvl_dir_scores)

-         if len(mlvl_scores) > 0:
-             mlvl_scores, mlvl_label_preds = mlvl_scores.max(dim=-1)
-             if self.test_cfg.use_rotate_nms:
-                 nms_func = nms_gpu
-             else:
-                 nms_func = nms_normal_gpu
-             selected = nms_func(mlvl_bboxes_for_nms, mlvl_scores,
-                                 self.test_cfg.nms_thr)
-         else:
-             selected = []
-
-         if len(selected) > 0:
-             selected_bboxes = mlvl_bboxes[selected]
-             selected_scores = mlvl_scores[selected]
-             selected_label_preds = mlvl_label_preds[selected]
-             selected_dir_scores = mlvl_dir_scores[selected]
-             dir_rot = box_torch_ops.limit_period(
-                 selected_bboxes[..., -1] - self.dir_offset,
-                 self.dir_limit_offset, np.pi)
-             selected_bboxes[..., -1] = (
-                 dir_rot + self.dir_offset +
-                 np.pi * selected_dir_scores.to(selected_bboxes.dtype))
-             return dict(
-                 box3d_lidar=selected_bboxes.cpu(),
-                 scores=selected_scores.cpu(),
-                 label_preds=selected_label_preds.cpu(),
-                 sample_idx=input_meta['sample_idx'],
-             )
-         return dict(
-             box3d_lidar=mlvl_scores.new_zeros([0, 7]).cpu(),
-             scores=mlvl_scores.new_zeros([0]).cpu(),
-             label_preds=mlvl_scores.new_zeros([0, 4]).cpu(),
-             sample_idx=input_meta['sample_idx'],
-         )
+         if self.use_sigmoid_cls:
+             # Add a dummy background class to the front when using sigmoid
+             padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+             mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+         score_thr = cfg.get('score_thr', 0)
+         results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
+                                        mlvl_scores, score_thr, cfg.max_num,
+                                        cfg, mlvl_dir_scores)
+         bboxes, scores, labels, dir_scores = results
+         if bboxes.shape[0] > 0:
+             # TODO: move dir_offset to box coder
+             dir_rot = box_torch_ops.limit_period(
+                 bboxes[..., 6] - self.dir_offset,
+                 self.dir_limit_offset, np.pi)
+             bboxes[..., 6] = (
+                 dir_rot + self.dir_offset +
+                 np.pi * dir_scores.to(bboxes.dtype))
+         return bboxes, scores, labels
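
get_bboxes_single now delegates thresholding and NMS to box3d_multiclass_nms and returns a plain (bboxes, scores, labels) tuple instead of a per-sample dict; the only post-processing left in the head is the yaw fix-up, which wraps the regressed angle into a half-period window around dir_offset and lets the binary direction score pick between the two opposite headings. A standalone sketch of that step; the limit_period definition below follows the common SECOND-style helper and is an assumption, since box_torch_ops itself is not part of this diff:

    import numpy as np
    import torch

    def limit_period(val, offset=0.5, period=np.pi):
        # assumed definition of box_torch_ops.limit_period
        return val - torch.floor(val / period + offset) * period

    dir_offset, dir_limit_offset = 0.0, 0.0   # illustrative values
    yaw = torch.tensor([2.9])                 # regressed yaw from the bbox branch
    dir_score = torch.tensor([1.0])           # direction classifier says "flipped"

    dir_rot = limit_period(yaw - dir_offset, dir_limit_offset, np.pi)
    decoded = dir_rot + dir_offset + np.pi * dir_score.to(yaw.dtype)
    # dir_rot lies in [0, pi); the classifier contributes the extra pi when needed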

mmdet3d/models/anchor_heads/parta2_rpn_head.py → mmdet3d/models/dense_heads/parta2_rpn_head.py  (renamed)

@@ -6,23 +6,31 @@ import torch
  from mmdet3d.core import box_torch_ops, boxes3d_to_bev_torch_lidar
  from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
  from mmdet.models import HEADS
- from .second_head import SECONDHead
+ from .anchor3d_head import Anchor3DHead


  @HEADS.register_module()
- class PartA2RPNHead(SECONDHead):
-     """rpn head for PartA2
+ class PartA2RPNHead(Anchor3DHead):
+     """RPN head for PartA2
+
+     Note:
+         The main difference between the PartA2 RPN head and the Anchor3DHead
+         lies in their output during inference. PartA2 RPN head further returns
+         the original classification score for the second stage since the bbox
+         head in RoI head does not do classification task.
+
+         Different from RPN heads in 2D detectors, this RPN head does
+         multi-class classification task and uses FocalLoss like the SECOND and
+         PointPillars do. But this head uses class agnostic nms rather than
+         multi-class nms.

      Args:
-         class_name (list[str]): name of classes (TODO: to be removed)
+         num_classes (int): Number of classes.
          in_channels (int): Number of channels in the input feature map.
          train_cfg (dict): train configs
          test_cfg (dict): test configs
          feat_channels (int): Number of channels of the feature map.
          use_direction_classifier (bool): Whether to add a direction classifier.
-         encode_bg_as_zeros (bool): Whether to use sigmoid of softmax
-             (TODO: to be removed)
-         box_code_size (int): The size of box code.
          anchor_generator(dict): Config dict of anchor generator.
          assigner_per_size (bool): Whether to do assignment for each separate
              anchor size.

@@ -37,17 +45,15 @@ class PartA2RPNHead(SECONDHead):
          loss_cls (dict): Config of classification loss.
          loss_bbox (dict): Config of localization loss.
          loss_dir (dict): Config of direction classifier loss.
-     """  # npqa:W293
+     """

      def __init__(self,
-                  class_name,
+                  num_classes,
                   in_channels,
                   train_cfg,
                   test_cfg,
                   feat_channels=256,
                   use_direction_classifier=True,
-                  encode_bg_as_zeros=False,
-                  box_code_size=7,
                   anchor_generator=dict(
                       type='Anchor3DRangeGenerator',
                       range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],

@@ -69,49 +75,11 @@ class PartA2RPNHead(SECONDHead):
                   loss_bbox=dict(
                       type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
                   loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
-         super().__init__(class_name, in_channels, train_cfg, test_cfg,
-                          feat_channels, use_direction_classifier,
-                          encode_bg_as_zeros, box_code_size, anchor_generator,
+         super().__init__(num_classes, in_channels, train_cfg, test_cfg,
+                          feat_channels, use_direction_classifier,
+                          anchor_generator,
                           assigner_per_size, assign_per_class, diff_rad_by_sin,
                           dir_offset, dir_limit_offset, bbox_coder, loss_cls,
                           loss_bbox, loss_dir)

-     def get_bboxes(self, cls_scores, bbox_preds, dir_cls_preds, input_metas,
-                    cfg, rescale=False):
-         assert len(cls_scores) == len(bbox_preds)
-         assert len(cls_scores) == len(dir_cls_preds)
-         num_levels = len(cls_scores)
-         featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
-         device = cls_scores[0].device
-         mlvl_anchors = self.anchor_generator.grid_anchors(
-             featmap_sizes, device=device)
-         mlvl_anchors = [
-             anchor.reshape(-1, self.box_code_size) for anchor in mlvl_anchors
-         ]
-
-         result_list = []
-         for img_id in range(len(input_metas)):
-             cls_score_list = [
-                 cls_scores[i][img_id].detach() for i in range(num_levels)
-             ]
-             bbox_pred_list = [
-                 bbox_preds[i][img_id].detach() for i in range(num_levels)
-             ]
-             dir_cls_pred_list = [
-                 dir_cls_preds[i][img_id].detach() for i in range(num_levels)
-             ]
-
-             input_meta = input_metas[img_id]
-             proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
-                                                dir_cls_pred_list, mlvl_anchors,
-                                                input_meta, cfg, rescale)
-             result_list.append(proposals)
-         return result_list
-
      def get_bboxes_single(self,
                            cls_scores,

@@ -155,7 +123,7 @@ class PartA2RPNHead(SECONDHead):
                  anchors = anchors[topk_inds, :]
                  bbox_pred = bbox_pred[topk_inds, :]
                  max_scores = topk_scores
-                 cls_score = cls_score[topk_inds, :]
+                 cls_score = scores[topk_inds, :]
                  dir_cls_score = dir_cls_score[topk_inds]
                  pred_labels = pred_labels[topk_inds]

@@ -171,8 +139,12 @@ class PartA2RPNHead(SECONDHead):
          mlvl_max_scores = torch.cat(mlvl_max_scores)
          mlvl_label_pred = torch.cat(mlvl_label_pred)
          mlvl_dir_scores = torch.cat(mlvl_dir_scores)
-         mlvl_cls_score = torch.cat(
-             mlvl_cls_score)  # shape [k, num_class] before sigmoid
+         # shape [k, num_class] before sigmoid
+         # PartA2 need to keep raw classification score
+         # becase the bbox head in the second stage does not have
+         # classification branch,
+         # roi head need this score as classification score
+         mlvl_cls_score = torch.cat(mlvl_cls_score)

          score_thr = cfg.get('score_thr', 0)
          result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms,

@@ -180,7 +152,6 @@ class PartA2RPNHead(SECONDHead):
                                           mlvl_cls_score, mlvl_dir_scores,
                                           score_thr, cfg.nms_post, cfg)
-         result.update(dict(sample_idx=input_meta['sample_idx']))
          return result

      def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms,
@@ -232,14 +203,14 @@ class PartA2RPNHead(SECONDHead):
...
@@ -232,14 +203,14 @@ class PartA2RPNHead(SECONDHead):
scores
=
scores
[
inds
]
scores
=
scores
[
inds
]
cls_scores
=
cls_scores
[
inds
]
cls_scores
=
cls_scores
[
inds
]
return
dict
(
return
dict
(
box
3d_lidar
=
bboxes
,
box
es_3d
=
bboxes
,
scores
=
scores
,
scores
_3d
=
scores
,
label
_preds
=
labels
,
label
s_3d
=
labels
,
cls_preds
=
cls_scores
# raw scores [max_num, cls_num]
cls_preds
=
cls_scores
# raw scores [max_num, cls_num]
)
)
else
:
else
:
return
dict
(
return
dict
(
box
3d_lidar
=
mlvl_bboxes
.
new_zeros
([
0
,
self
.
box_code_size
]),
box
es_3d
=
mlvl_bboxes
.
new_zeros
([
0
,
self
.
box_code_size
]),
scores
=
mlvl_bboxes
.
new_zeros
([
0
]),
scores
_3d
=
mlvl_bboxes
.
new_zeros
([
0
]),
label
_preds
=
mlvl_bboxes
.
new_zeros
([
0
]),
label
s_3d
=
mlvl_bboxes
.
new_zeros
([
0
]),
cls_preds
=
mlvl_bboxes
.
new_zeros
([
0
,
mlvl_cls_score
.
shape
[
-
1
]]))
cls_preds
=
mlvl_bboxes
.
new_zeros
([
0
,
mlvl_cls_score
.
shape
[
-
1
]]))
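
After these renames, every per-sample proposal dict produced by PartA2RPNHead uses the boxes_3d / scores_3d / labels_3d convention that the RoI head reads further down in this commit, plus the raw cls_preds kept for the second stage. A sketch of the expected structure; the shapes assume 512 kept proposals and 3 classes, matching the assertions in the updated tests, and the zero tensors are placeholders:

    import torch

    proposal = dict(
        boxes_3d=torch.zeros(512, 7),                  # (x, y, z, w, l, h, ry)
        scores_3d=torch.zeros(512),                    # scores after class-agnostic NMS
        labels_3d=torch.zeros(512, dtype=torch.long),  # per-proposal class prediction
        cls_preds=torch.zeros(512, 3))                 # raw per-class scores for the RoI head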

mmdet3d/models/anchor_heads/train_mixins.py → mmdet3d/models/dense_heads/train_mixins.py  (file moved, no content changes)

mmdet3d/models/detectors/mvx_single_stage.py

@@ -2,6 +2,7 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from mmdet3d.core import bbox3d2result
  from mmdet3d.ops import Voxelization
  from mmdet.models import DETECTORS
  from .. import builder

@@ -154,9 +155,13 @@ class MVXSingleStageDetector(BaseDetector):
                      rescale=False):
          x = self.extract_feat(points, img, img_meta)
          outs = self.pts_bbox_head(x)
-         bbox_inputs = outs + (img_meta, rescale)
          bbox_list = self.pts_bbox_head.get_bboxes(
-             *bbox_inputs)
-         return bbox_list
+             *outs, img_meta, rescale=rescale)
+         bbox_results = [
+             bbox3d2result(bboxes, scores, labels)
+             for bboxes, scores, labels in bbox_list
+         ]
+         return bbox_results[0]

      def aug_test(self, points, imgs, img_metas, rescale=False):
          raise NotImplementedError
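
The same conversion is repeated in the two detectors below: the head now yields (bboxes, scores, labels) tuples and the detector packs them with bbox3d2result before returning the first (and only) sample. The helper itself is not shown in this diff; a plausible sketch of its behaviour, assuming it mirrors the boxes_3d / scores_3d / labels_3d keys introduced elsewhere in this commit:

    # hedged sketch of mmdet3d.core.bbox3d2result (assumed behaviour, not from this diff)
    def bbox3d2result(bboxes, scores, labels):
        return dict(
            boxes_3d=bboxes.cpu(),
            scores_3d=scores.cpu(),
            labels_3d=labels.cpu())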

mmdet3d/models/detectors/mvx_two_stage.py

@@ -2,6 +2,7 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from mmdet3d.core import bbox3d2result
  from mmdet3d.ops import Voxelization
  from mmdet.models import DETECTORS
  from .. import builder

@@ -272,9 +273,13 @@ class MVXTwoStageDetector(BaseDetector):
      def simple_test_pts(self, x, img_meta, rescale=False):
          outs = self.pts_bbox_head(x)
-         bbox_inputs = outs + (img_meta, rescale)
          bbox_list = self.pts_bbox_head.get_bboxes(
-             *bbox_inputs)
-         return bbox_list
+             *outs, img_meta, rescale=rescale)
+         bbox_results = [
+             bbox3d2result(bboxes, scores, labels)
+             for bboxes, scores, labels in bbox_list
+         ]
+         return bbox_results[0]

      def simple_test(self,
                      points,

mmdet3d/models/detectors/voxelnet.py

  import torch
  import torch.nn.functional as F
+ from mmdet3d.core import bbox3d2result
  from mmdet3d.ops import Voxelization
  from mmdet.models import DETECTORS, SingleStageDetector
  from .. import builder

@@ -83,9 +84,12 @@ class VoxelNet(SingleStageDetector):
      def simple_test(self, points, img_meta, gt_bboxes_3d=None, rescale=False):
          x = self.extract_feat(points, img_meta)
          outs = self.bbox_head(x)
-         bbox_inputs = outs + (img_meta, rescale)
-         bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
-         return bbox_list
+         bbox_list = self.bbox_head.get_bboxes(*outs, img_meta, rescale=rescale)
+         bbox_results = [
+             bbox3d2result(bboxes, scores, labels)
+             for bboxes, scores, labels in bbox_list
+         ]
+         return bbox_results[0]


  @DETECTORS.register_module()

mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py

@@ -14,7 +14,7 @@ from mmdet.models import HEADS
  @HEADS.register_module()
  class PartA2BboxHead(nn.Module):
-     """PartA2 rcnn box head.
+     """PartA2 RoI head.

      Args:
          num_classes (int): The number of classes to prediction.

@@ -533,17 +533,10 @@ class PartA2BboxHead(nn.Module):
                  cfg.use_rotate_nms)
              selected_bboxes = cur_rcnn_boxes3d[selected]
              selected_label_preds = cur_class_labels[selected]
-             if cfg.use_raw_score:
-                 selected_scores = cur_cls_score[selected]
-             else:
-                 selected_scores = torch.sigmoid(cur_cls_score)[selected]
-
-             cur_result = dict(
-                 box3d_lidar=selected_bboxes.cpu(),
-                 scores=selected_scores.cpu(),
-                 label_preds=selected_label_preds.cpu(),
-                 sample_idx=img_meta[batch_id]['sample_idx'])
-             result_list.append(cur_result)
+             selected_scores = cur_cls_score[selected]
+             result_list.append(
+                 (selected_bboxes, selected_scores, selected_label_preds))

          return result_list

      def multi_class_nms(self,

mmdet3d/models/roi_heads/part_aggregation_roi_head.py

  import torch.nn.functional as F
  from mmdet3d.core import AssignResult
- from mmdet3d.core.bbox import bbox3d2roi
+ from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
  from mmdet.core import build_assigner, build_sampler
  from mmdet.models import HEADS
  from ..builder import build_head, build_roi_extractor

@@ -95,6 +95,9 @@ class PartAggregationROIHead(Base3DRoIHead):
                      **kwargs):
          """Simple testing forward function of PartAggregationROIHead

+         Note:
+             This function assumes that the batch size is 1
+
          Args:
              feats_dict (dict): Contains features from the first stage.
              voxels_dict (dict): Contains information of voxels.

@@ -102,15 +105,15 @@ class PartAggregationROIHead(Base3DRoIHead):
              proposal_list (list[dict]): Proposal information from rpn.

          Returns:
-             list[dict]: Bbox results of each batch.
+             dict: Bbox results of one frame.
          """
          assert self.with_bbox, 'Bbox head must be implemented.'
          assert self.with_semantic

          semantic_results = self.semantic_head(feats_dict['seg_features'])

-         rois = bbox3d2roi([res['box3d_lidar'] for res in proposal_list])
-         label_preds = [res['label_preds'] for res in proposal_list]
+         rois = bbox3d2roi([res['boxes_3d'] for res in proposal_list])
+         labels_3d = [res['labels_3d'] for res in proposal_list]
          cls_preds = [res['cls_preds'] for res in proposal_list]

          bbox_results = self._bbox_forward(feats_dict['seg_features'],
                                            semantic_results['part_feats'],

@@ -120,11 +123,16 @@ class PartAggregationROIHead(Base3DRoIHead):
              rois,
              bbox_results['cls_score'],
              bbox_results['bbox_pred'],
-             label_preds,
+             labels_3d,
              cls_preds,
              img_meta,
              cfg=self.test_cfg)
-         return bbox_list
+
+         bbox_results = [
+             bbox3d2result(bboxes, scores, labels)
+             for bboxes, scores, labels in bbox_list
+         ]
+         return bbox_results[0]

      def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                              sampling_results):

@@ -164,8 +172,8 @@ class PartAggregationROIHead(Base3DRoIHead):
          # bbox assign
          for batch_idx in range(len(proposal_list)):
              cur_proposal_list = proposal_list[batch_idx]
-             cur_boxes = cur_proposal_list['box3d_lidar']
-             cur_label_preds = cur_proposal_list['label_preds']
+             cur_boxes = cur_proposal_list['boxes_3d']
+             cur_labels_3d = cur_proposal_list['labels_3d']
              cur_gt_bboxes = gt_bboxes_3d[batch_idx]
              cur_gt_labels = gt_labels_3d[batch_idx]

@@ -178,7 +186,7 @@ class PartAggregationROIHead(Base3DRoIHead):
          if isinstance(self.bbox_assigner, list):  # for multi classes
              for i, assigner in enumerate(self.bbox_assigner):
                  gt_per_cls = (cur_gt_labels == i)
-                 pred_per_cls = (cur_label_preds == i)
+                 pred_per_cls = (cur_labels_3d == i)
                  cur_assign_res = assigner.assign(
                      cur_boxes[pred_per_cls],
                      cur_gt_bboxes[gt_per_cls],

requirements/runtime.txt

  matplotlib
  mmcv>=0.5.1
- numba==0.45.1
+ numba==0.48.0
  numpy
  # need older pillow until torchvision is fixed
  Pillow<=6.2.2

tests/test_heads.py

@@ -65,7 +65,7 @@ def _get_rpn_head_cfg(fname):
      return rpn_head, train_cfg.rpn_proposal


- def test_second_head_loss():
+ def test_anchor3d_head_loss():
      if not torch.cuda.is_available():
          pytest.skip('test requires GPU and torch+cuda')
      bbox_head_cfg = _get_head_cfg(

@@ -117,7 +117,7 @@ def test_second_head_loss():
      assert empty_gt_losses['loss_rpn_dir'][0] == 0


- def test_second_head_getboxes():
+ def test_anchor3d_head_getboxes():
      if not torch.cuda.is_available():
          pytest.skip('test requires GPU and torch+cuda')
      bbox_head_cfg = _get_head_cfg(

@@ -140,7 +140,7 @@ def test_second_head_getboxes():
      cls_score[0] -= 1.5  # too many positive samples may cause cuda oom
      result_list = self.get_bboxes(cls_score, bbox_pred, dir_cls_preds,
                                    input_metas)
-     assert (result_list[0]['scores'] > 0.3).all()
+     assert (result_list[0][1] > 0.3).all()


  def test_parta2_rpnhead_getboxes():

@@ -166,7 +166,7 @@ def test_parta2_rpnhead_getboxes():
      cls_score[0] -= 1.5  # too many positive samples may cause cuda oom
      result_list = self.get_bboxes(cls_score, bbox_pred, dir_cls_preds,
                                    input_metas, proposal_cfg)
-     assert result_list[0]['scores'].shape == torch.Size([512])
-     assert result_list[0]['label_preds'].shape == torch.Size([512])
+     assert result_list[0]['scores_3d'].shape == torch.Size([512])
+     assert result_list[0]['labels_3d'].shape == torch.Size([512])
      assert result_list[0]['cls_preds'].shape == torch.Size([512, 3])
-     assert result_list[0]['box3d_lidar'].shape == torch.Size([512, 7])
+     assert result_list[0]['boxes_3d'].shape == torch.Size([512, 7])
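
The renamed tests can be selected by keyword as usual; they skip themselves when CUDA is unavailable, so a GPU machine is needed to exercise them. For example, from the repository root (the selection expression is illustrative):

    import pytest

    # same as: pytest tests/test_heads.py -k "anchor3d_head or parta2_rpnhead"
    pytest.main(['tests/test_heads.py', '-k', 'anchor3d_head or parta2_rpnhead'])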