Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
4040dbda
Commit
4040dbda
authored
Apr 27, 2020
by
zhangwenwei
Browse files
Refactor anchor generator and box coder
parent
148fea12
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
259 additions
and
186 deletions
+259
-186
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+1
-1
mmdet3d/datasets/nuscenes2d_dataset.py
mmdet3d/datasets/nuscenes2d_dataset.py
+0
-38
mmdet3d/datasets/nuscenes_dataset.py
mmdet3d/datasets/nuscenes_dataset.py
+1
-1
mmdet3d/datasets/pipelines/__init__.py
mmdet3d/datasets/pipelines/__init__.py
+8
-1
mmdet3d/datasets/pipelines/dbsampler.py
mmdet3d/datasets/pipelines/dbsampler.py
+1
-1
mmdet3d/datasets/pipelines/formating.py
mmdet3d/datasets/pipelines/formating.py
+1
-1
mmdet3d/datasets/pipelines/loading.py
mmdet3d/datasets/pipelines/loading.py
+1
-1
mmdet3d/datasets/pipelines/train_aug.py
mmdet3d/datasets/pipelines/train_aug.py
+2
-2
mmdet3d/datasets/registry.py
mmdet3d/datasets/registry.py
+1
-1
mmdet3d/models/anchor_heads/boxvelo_head.py
mmdet3d/models/anchor_heads/boxvelo_head.py
+17
-21
mmdet3d/models/anchor_heads/second_head.py
mmdet3d/models/anchor_heads/second_head.py
+27
-55
mmdet3d/models/anchor_heads/train_mixins.py
mmdet3d/models/anchor_heads/train_mixins.py
+2
-3
mmdet3d/models/builder.py
mmdet3d/models/builder.py
+2
-3
mmdet3d/models/registry.py
mmdet3d/models/registry.py
+1
-1
mmdet3d/models/utils/__init__.py
mmdet3d/models/utils/__init__.py
+0
-3
mmdet3d/models/utils/weight_init.py
mmdet3d/models/utils/weight_init.py
+0
-46
mmdet3d/utils/__init__.py
mmdet3d/utils/__init__.py
+3
-2
tests/test_anchor.py
tests/test_anchor.py
+149
-0
tests/test_config.py
tests/test_config.py
+29
-1
tools/train.py
tools/train.py
+13
-4
No files found.
mmdet3d/datasets/kitti_dataset.py
View file @
4040dbda
...
...
@@ -8,8 +8,8 @@ import torch
import
torch.utils.data
as
torch_data
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets.pipelines
import
Compose
from
..core.bbox
import
box_np_ops
from
.pipelines
import
Compose
from
.utils
import
remove_dontcare
...
...
mmdet3d/datasets/nuscenes2d_dataset.py
deleted
100644 → 0
View file @
148fea12
from
pycocotools.coco
import
COCO
from
mmdet3d.core.evaluation.coco_utils
import
getImgIds
from
mmdet.datasets
import
DATASETS
,
CocoDataset
@
DATASETS
.
register_module
class
NuScenes2DDataset
(
CocoDataset
):
CLASSES
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
'bicycle'
,
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
'barrier'
)
def
load_annotations
(
self
,
ann_file
):
if
not
self
.
class_names
:
self
.
class_names
=
self
.
CLASSES
self
.
coco
=
COCO
(
ann_file
)
# send class_names into the get id
# in case we only need to train on several classes
# by default self.class_names = CLASSES
self
.
cat_ids
=
self
.
coco
.
getCatIds
(
catNms
=
self
.
class_names
)
self
.
cat2label
=
{
cat_id
:
i
# + 1 rm +1 here thus the 0-79 are fg, 80 is bg
for
i
,
cat_id
in
enumerate
(
self
.
cat_ids
)
}
# send cat ids to the get img id
# in case we only need to train on several classes
if
len
(
self
.
cat_ids
)
<
len
(
self
.
CLASSES
):
self
.
img_ids
=
getImgIds
(
self
.
coco
,
catIds
=
self
.
cat_ids
)
else
:
self
.
img_ids
=
self
.
coco
.
getImgIds
()
img_infos
=
[]
for
i
in
self
.
img_ids
:
info
=
self
.
coco
.
loadImgs
([
i
])[
0
]
info
[
'filename'
]
=
info
[
'file_name'
]
img_infos
.
append
(
info
)
return
img_infos
mmdet3d/datasets/nuscenes_dataset.py
View file @
4040dbda
...
...
@@ -9,8 +9,8 @@ import torch.utils.data as torch_data
from
nuscenes.utils.data_classes
import
Box
as
NuScenesBox
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets.pipelines
import
Compose
from
..core.bbox
import
box_np_ops
from
.pipelines
import
Compose
@
DATASETS
.
register_module
...
...
mmdet3d/datasets/pipelines/__init__.py
View file @
4040dbda
from
mmdet.datasets.pipelines
import
Compose
from
.dbsampler
import
DataBaseSampler
,
MMDataBaseSampler
from
.formating
import
DefaultFormatBundle
,
DefaultFormatBundle3D
from
.loading
import
LoadMultiViewImageFromFiles
,
LoadPointsFromFile
from
.train_aug
import
(
GlobalRotScale
,
ObjectNoise
,
ObjectRangeFilter
,
ObjectSample
,
PointShuffle
,
PointsRangeFilter
,
RandomFlip3D
)
__all__
=
[
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScale'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Collect3D'
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Collect3D'
,
'Compose'
,
'LoadMultiViewImageFromFiles'
,
'LoadPointsFromFile'
,
'DefaultFormatBundle'
,
'DefaultFormatBundle3D'
,
'DataBaseSampler'
,
'MMDataBaseSampler'
]
mmdet3d/datasets/pipelines/dbsampler.py
View file @
4040dbda
...
...
@@ -68,7 +68,7 @@ class DataBaseSampler(object):
db_infos
=
pickle
.
load
(
f
)
# filter database infos
from
mmdet
3d
.apis
import
get_root_logger
from
mmdet.apis
import
get_root_logger
logger
=
get_root_logger
()
for
k
,
v
in
db_infos
.
items
():
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos'
)
...
...
mmdet3d/datasets/pipelines/formating.py
View file @
4040dbda
import
numpy
as
np
from
mmcv.parallel
import
DataContainer
as
DC
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
to_tensor
from
mmdet.datasets.registry
import
PIPELINES
PIPELINES
.
_module_dict
.
pop
(
'DefaultFormatBundle'
)
...
...
mmdet3d/datasets/pipelines/loading.py
View file @
4040dbda
...
...
@@ -3,7 +3,7 @@ import os.path as osp
import
mmcv
import
numpy
as
np
from
mmdet.datasets.
registry
import
PIPELINES
from
mmdet.datasets.
builder
import
PIPELINES
@
PIPELINES
.
register_module
...
...
mmdet3d/datasets/pipelines/train_aug.py
View file @
4040dbda
import
numpy
as
np
from
mmcv.utils
import
build_from_cfg
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet
3d.utils
import
build_from_cfg
from
mmdet
.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
RandomFlip
from
mmdet.datasets.registry
import
PIPELINES
from
..registry
import
OBJECTSAMPLERS
from
.data_augment_utils
import
noise_per_object_v3_
...
...
mmdet3d/datasets/registry.py
View file @
4040dbda
from
mm
det
.utils
import
Registry
from
mm
cv
.utils
import
Registry
OBJECTSAMPLERS
=
Registry
(
'Object sampler'
)
mmdet3d/models/anchor_heads/boxvelo_head.py
View file @
4040dbda
import
numpy
as
np
import
torch
from
mmcv.cnn
import
normal_init
from
mmcv.cnn
import
bias_init_with_prob
,
normal_init
from
mmdet3d.core
import
box_torch_ops
,
boxes3d_to_bev_torch_lidar
from
mmdet3d.ops.iou3d.iou3d_utils
import
nms_gpu
,
nms_normal_gpu
from
mmdet.models
import
HEADS
from
..utils
import
bias_init_with_prob
from
.second_head
import
SECONDHead
...
...
@@ -31,25 +30,25 @@ class Anchor3DVeloHead(SECONDHead):
in_channels
,
train_cfg
,
test_cfg
,
cache_anchor
=
False
,
feat_channels
=
256
,
use_direction_classifier
=
True
,
encode_bg_as_zeros
=
False
,
box_code_size
=
9
,
anchor_generator
=
dict
(
type
=
'AnchorGeneratorRange'
,
),
anchor_range
=
[
0
,
-
39.68
,
-
1.78
,
69.12
,
39.68
,
-
1.78
],
anchor_strides
=
[
2
],
anchor_sizes
=
[[
1.6
,
3.9
,
1.56
]],
anchor_rotations
=
[
0
,
1.57
],
anchor_custom_values
=
[
0
,
0
],
anchor_generator
=
dict
(
type
=
'Anchor3DRangeGenerator'
,
range
=
[
0
,
-
39.68
,
-
1.78
,
69.12
,
39.68
,
-
1.78
],
strides
=
[
2
],
sizes
=
[[
1.6
,
3.9
,
1.56
]],
rotations
=
[
0
,
1.57
],
custom_values
=
[
0
,
0
],
reshape_out
=
True
,
),
assigner_per_size
=
False
,
assign_per_class
=
False
,
diff_rad_by_sin
=
True
,
dir_offset
=
0
,
dir_limit_offset
=
1
,
target_means
=
(.
0
,
.
0
,
.
0
,
.
0
),
target_stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
bbox_coder
=
dict
(
type
=
'Residual3DBoxCoder'
,
),
bbox_coder
=
dict
(
type
=
'DeltaXYZWLHRBBoxCoder'
),
loss_cls
=
dict
(
type
=
'CrossEntropyLoss'
,
use_sigmoid
=
True
,
...
...
@@ -58,14 +57,11 @@ class Anchor3DVeloHead(SECONDHead):
type
=
'SmoothL1Loss'
,
beta
=
1.0
/
9.0
,
loss_weight
=
2.0
),
loss_dir
=
dict
(
type
=
'CrossEntropyLoss'
,
loss_weight
=
0.2
)):
super
().
__init__
(
class_names
,
in_channels
,
train_cfg
,
test_cfg
,
cache_anchor
,
feat_channels
,
use_direction_classifier
,
feat_channels
,
use_direction_classifier
,
encode_bg_as_zeros
,
box_code_size
,
anchor_generator
,
anchor_range
,
anchor_strides
,
anchor_sizes
,
anchor_rotations
,
anchor_custom_values
,
assigner_per_size
,
assign_per_class
,
diff_rad_by_sin
,
dir_offset
,
dir_limit_offset
,
target_means
,
target_stds
,
bbox_coder
,
loss_cls
,
loss_bbox
,
loss_dir
)
dir_offset
,
dir_limit_offset
,
bbox_coder
,
loss_cls
,
loss_bbox
,
loss_dir
)
self
.
num_classes
=
num_classes
# build head layers & losses
if
not
self
.
use_sigmoid_cls
:
...
...
@@ -131,9 +127,9 @@ class Anchor3DVeloHead(SECONDHead):
scores
=
scores
[
topk_inds
,
:]
dir_cls_score
=
dir_cls_score
[
topk_inds
]
bboxes
=
self
.
bbox_coder
.
decode
_torch
(
anchors
,
bbox_pred
,
self
.
target_means
,
self
.
target_stds
)
bboxes
=
self
.
bbox_coder
.
decode
(
anchors
,
bbox_pred
,
self
.
target_means
,
self
.
target_stds
)
mlvl_bboxes
.
append
(
bboxes
)
mlvl_scores
.
append
(
scores
)
mlvl_dir_scores
.
append
(
dir_cls_score
)
...
...
mmdet3d/models/anchor_heads/second_head.py
View file @
4040dbda
from
__future__
import
division
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
normal_init
from
mmcv.cnn
import
bias_init_with_prob
,
normal_init
from
mmdet3d.core
import
(
PseudoSampler
,
box_torch_ops
,
boxes3d_to_bev_torch_lidar
,
build_anchor_generator
,
...
...
@@ -12,7 +10,6 @@ from mmdet3d.core import (PseudoSampler, box_torch_ops,
from
mmdet3d.ops.iou3d.iou3d_utils
import
nms_gpu
,
nms_normal_gpu
from
mmdet.models
import
HEADS
from
..builder
import
build_loss
from
..utils
import
bias_init_with_prob
from
.train_mixins
import
AnchorTrainMixin
...
...
@@ -37,25 +34,24 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
in_channels
,
train_cfg
,
test_cfg
,
cache_anchor
=
False
,
feat_channels
=
256
,
use_direction_classifier
=
True
,
encode_bg_as_zeros
=
False
,
box_code_size
=
7
,
anchor_generator
=
dict
(
type
=
'AnchorGeneratorRange'
),
anchor_range
=
[
0
,
-
39.68
,
-
1.78
,
69.12
,
39.68
,
-
1.78
],
anchor_strides
=
[
2
],
anchor_sizes
=
[[
1.6
,
3.9
,
1.56
]],
anchor_rotations
=
[
0
,
1.57
],
anchor_custom_values
=
[],
anchor_generator
=
dict
(
type
=
'Anchor3DRangeGenerator'
,
range
=
[
0
,
-
39.68
,
-
1.78
,
69.12
,
39.68
,
-
1.78
],
strides
=
[
2
],
sizes
=
[[
1.6
,
3.9
,
1.56
]],
rotations
=
[
0
,
1.57
],
custom_values
=
[],
reshape_out
=
False
),
assigner_per_size
=
False
,
assign_per_class
=
False
,
diff_rad_by_sin
=
True
,
dir_offset
=
0
,
dir_limit_offset
=
1
,
target_means
=
(.
0
,
.
0
,
.
0
,
.
0
),
target_stds
=
(
1.0
,
1.0
,
1.0
,
1.0
),
bbox_coder
=
dict
(
type
=
'Residual3DBoxCoder'
),
bbox_coder
=
dict
(
type
=
'DeltaXYZWLHRBBoxCoder'
),
loss_cls
=
dict
(
type
=
'CrossEntropyLoss'
,
use_sigmoid
=
True
,
...
...
@@ -94,29 +90,9 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
]
# build anchor generator
self
.
anchor_range
=
anchor_range
self
.
anchor_rotations
=
anchor_rotations
self
.
anchor_strides
=
anchor_strides
self
.
anchor_sizes
=
anchor_sizes
self
.
target_means
=
target_means
self
.
target_stds
=
target_stds
self
.
anchor_generators
=
[]
self
.
anchor_generator
=
build_anchor_generator
(
anchor_generator
)
# In 3D detection, the anchor stride is connected with anchor size
self
.
num_anchors
=
(
len
(
self
.
anchor_rotations
)
*
len
(
self
.
anchor_sizes
))
# if len(self.anchor_sizes) != self.anchor_strides:
# # this means different anchor in the same anchor strides
# anchor_sizes = [self.anchor_sizes]
for
anchor_stride
in
self
.
anchor_strides
:
anchor_generator
.
update
(
anchor_ranges
=
anchor_range
,
sizes
=
self
.
anchor_sizes
,
stride
=
anchor_stride
,
rotations
=
anchor_rotations
,
custom_values
=
anchor_custom_values
,
cache_anchor
=
cache_anchor
)
self
.
anchor_generators
.
append
(
build_anchor_generator
(
anchor_generator
))
self
.
num_anchors
=
self
.
anchor_generator
.
num_base_anchors
self
.
_init_layers
()
self
.
use_sigmoid_cls
=
loss_cls
.
get
(
'use_sigmoid'
,
False
)
...
...
@@ -152,7 +128,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
def
forward
(
self
,
feats
):
return
multi_apply
(
self
.
forward_single
,
feats
)
def
get_anchors
(
self
,
featmap_sizes
,
input_metas
):
def
get_anchors
(
self
,
featmap_sizes
,
input_metas
,
device
=
'cuda'
):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
...
...
@@ -161,16 +137,10 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
tuple: anchors of each image, valid flags of each image
"""
num_imgs
=
len
(
input_metas
)
num_levels
=
len
(
featmap_sizes
)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors
=
[]
for
i
in
range
(
num_levels
):
anchors
=
self
.
anchor_generators
[
i
].
grid_anchors
(
featmap_sizes
[
i
])
if
not
self
.
assigner_per_size
:
anchors
=
anchors
.
reshape
(
-
1
,
anchors
.
size
(
-
1
))
multi_level_anchors
.
append
(
anchors
)
multi_level_anchors
=
self
.
anchor_generator
.
grid_anchors
(
featmap_sizes
,
device
=
device
)
anchor_list
=
[
multi_level_anchors
for
_
in
range
(
num_imgs
)]
return
anchor_list
...
...
@@ -237,9 +207,10 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
input_metas
,
gt_bboxes_ignore
=
None
):
featmap_sizes
=
[
featmap
.
size
()[
-
2
:]
for
featmap
in
cls_scores
]
assert
len
(
featmap_sizes
)
==
len
(
self
.
anchor_generators
)
anchor_list
=
self
.
get_anchors
(
featmap_sizes
,
input_metas
)
assert
len
(
featmap_sizes
)
==
self
.
anchor_generator
.
num_levels
device
=
cls_scores
[
0
].
device
anchor_list
=
self
.
get_anchors
(
featmap_sizes
,
input_metas
,
device
=
device
)
label_channels
=
self
.
cls_out_channels
if
self
.
use_sigmoid_cls
else
1
cls_reg_targets
=
self
.
anchor_target_3d
(
anchor_list
,
...
...
@@ -288,12 +259,14 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
assert
len
(
cls_scores
)
==
len
(
dir_cls_preds
)
num_levels
=
len
(
cls_scores
)
featmap_sizes
=
[
cls_scores
[
i
].
shape
[
-
2
:]
for
i
in
range
(
num_levels
)]
device
=
cls_scores
[
0
].
device
mlvl_anchors
=
self
.
anchor_generators
.
grid_anchors
(
featmap_sizes
,
device
=
device
)
mlvl_anchors
=
[
self
.
anchor_generators
[
i
].
grid_anchors
(
cls_scores
[
i
].
size
()[
-
2
:]).
reshape
(
-
1
,
self
.
box_code_size
)
for
i
in
range
(
num_levels
)
anchor
.
reshape
(
-
1
,
self
.
box_code_size
)
for
anchor
in
mlvl_anchors
]
result_list
=
[]
for
img_id
in
range
(
len
(
input_metas
)):
cls_score_list
=
[
...
...
@@ -353,9 +326,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
bbox_pred
=
bbox_pred
[
thr_inds
]
scores
=
scores
[
thr_inds
]
dir_cls_scores
=
dir_cls_score
[
thr_inds
]
bboxes
=
self
.
bbox_coder
.
decode_torch
(
anchors
,
bbox_pred
,
self
.
target_means
,
self
.
target_stds
)
bboxes
=
self
.
bbox_coder
.
decode
(
anchors
,
bbox_pred
)
bboxes_for_nms
=
boxes3d_to_bev_torch_lidar
(
bboxes
)
mlvl_bboxes_for_nms
.
append
(
bboxes_for_nms
)
mlvl_bboxes
.
append
(
bboxes
)
...
...
@@ -383,6 +354,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
selected_scores
=
mlvl_scores
[
selected
]
selected_label_preds
=
mlvl_label_preds
[
selected
]
selected_dir_scores
=
mlvl_dir_scores
[
selected
]
# TODO: move dir_offset to box coder
dir_rot
=
box_torch_ops
.
limit_period
(
selected_bboxes
[...,
-
1
]
-
self
.
dir_offset
,
self
.
dir_limit_offset
,
np
.
pi
)
...
...
mmdet3d/models/anchor_heads/train_mixins.py
View file @
4040dbda
...
...
@@ -197,9 +197,8 @@ class AnchorTrainMixin(object):
if
gt_labels
is
not
None
:
labels
+=
num_classes
if
len
(
pos_inds
)
>
0
:
pos_bbox_targets
=
self
.
bbox_coder
.
encode_torch
(
sampling_result
.
pos_bboxes
,
sampling_result
.
pos_gt_bboxes
,
target_means
,
target_stds
)
pos_bbox_targets
=
self
.
bbox_coder
.
encode
(
sampling_result
.
pos_bboxes
,
sampling_result
.
pos_gt_bboxes
)
pos_dir_targets
=
get_direction_target
(
sampling_result
.
pos_bboxes
,
pos_bbox_targets
,
...
...
mmdet3d/models/builder.py
View file @
4040dbda
from
mmdet.models.builder
import
build
from
mmdet.models.registry
import
(
BACKBONES
,
DETECTORS
,
HEADS
,
LOSSES
,
NECKS
,
ROI_EXTRACTORS
,
SHARED_HEADS
)
from
mmdet.models.builder
import
(
BACKBONES
,
DETECTORS
,
HEADS
,
LOSSES
,
NECKS
,
ROI_EXTRACTORS
,
SHARED_HEADS
,
build
)
from
.registry
import
FUSION_LAYERS
,
MIDDLE_ENCODERS
,
VOXEL_ENCODERS
...
...
mmdet3d/models/registry.py
View file @
4040dbda
from
mm
det
.utils
import
Registry
from
mm
cv
.utils
import
Registry
VOXEL_ENCODERS
=
Registry
(
'voxel_encoder'
)
MIDDLE_ENCODERS
=
Registry
(
'middle_encoder'
)
...
...
mmdet3d/models/utils/__init__.py
deleted
100644 → 0
View file @
148fea12
from
mmdet.models.utils
import
ResLayer
,
bias_init_with_prob
__all__
=
[
'bias_init_with_prob'
,
'ResLayer'
]
mmdet3d/models/utils/weight_init.py
deleted
100644 → 0
View file @
148fea12
import
numpy
as
np
import
torch.nn
as
nn
def
xavier_init
(
module
,
gain
=
1
,
bias
=
0
,
distribution
=
'normal'
):
assert
distribution
in
[
'uniform'
,
'normal'
]
if
distribution
==
'uniform'
:
nn
.
init
.
xavier_uniform_
(
module
.
weight
,
gain
=
gain
)
else
:
nn
.
init
.
xavier_normal_
(
module
.
weight
,
gain
=
gain
)
if
hasattr
(
module
,
'bias'
):
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
normal_init
(
module
,
mean
=
0
,
std
=
1
,
bias
=
0
):
nn
.
init
.
normal_
(
module
.
weight
,
mean
,
std
)
if
hasattr
(
module
,
'bias'
):
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
uniform_init
(
module
,
a
=
0
,
b
=
1
,
bias
=
0
):
nn
.
init
.
uniform_
(
module
.
weight
,
a
,
b
)
if
hasattr
(
module
,
'bias'
):
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
kaiming_init
(
module
,
mode
=
'fan_out'
,
nonlinearity
=
'relu'
,
bias
=
0
,
distribution
=
'normal'
):
assert
distribution
in
[
'uniform'
,
'normal'
]
if
distribution
==
'uniform'
:
nn
.
init
.
kaiming_uniform_
(
module
.
weight
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
else
:
nn
.
init
.
kaiming_normal_
(
module
.
weight
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
if
hasattr
(
module
,
'bias'
):
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
bias_init_with_prob
(
prior_prob
):
""" initialize conv/fc bias value according to giving probablity"""
bias_init
=
float
(
-
np
.
log
((
1
-
prior_prob
)
/
prior_prob
))
return
bias_init
mmdet3d/utils/__init__.py
View file @
4040dbda
from
mmdet.utils
import
(
Registry
,
build_from_cfg
,
get_model_complexity_info
,
get_root_logger
,
print_log
)
from
mmcv.utils
import
Registry
,
build_from_cfg
from
mmdet.utils
import
get_model_complexity_info
,
get_root_logger
,
print_log
from
.collect_env
import
collect_env
__all__
=
[
...
...
tests/test_anchor.py
0 → 100644
View file @
4040dbda
"""
CommandLine:
pytest tests/test_anchor.py
xdoctest tests/test_anchor.py zero
"""
import
torch
def
test_aligned_anchor_generator
():
from
mmdet3d.core.anchor
import
build_anchor_generator
if
torch
.
cuda
.
is_available
():
device
=
'cuda'
else
:
device
=
'cpu'
anchor_generator_cfg
=
dict
(
type
=
'AlignedAnchor3DRangeGenerator'
,
ranges
=
[[
-
51.2
,
-
51.2
,
-
1.80
,
51.2
,
51.2
,
-
1.80
]],
strides
=
[
1
,
2
,
4
],
sizes
=
[
[
0.8660
,
2.5981
,
1.
],
# 1.5/sqrt(3)
[
0.5774
,
1.7321
,
1.
],
# 1/sqrt(3)
[
1.
,
1.
,
1.
],
[
0.4
,
0.4
,
1
],
],
custom_values
=
[
0
,
0
],
rotations
=
[
0
,
1.57
],
size_per_range
=
False
,
reshape_out
=
True
)
featmap_sizes
=
[(
256
,
256
),
(
128
,
128
),
(
64
,
64
)]
anchor_generator
=
build_anchor_generator
(
anchor_generator_cfg
)
assert
anchor_generator
.
num_base_anchors
==
8
# check base anchors
expected_grid_anchors
=
[
torch
.
tensor
([[
-
51.0000
,
-
51.0000
,
-
1.8000
,
0.8660
,
2.5981
,
1.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
51.0000
,
-
51.0000
,
-
1.8000
,
0.4000
,
0.4000
,
1.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
50.6000
,
-
51.0000
,
-
1.8000
,
0.4000
,
0.4000
,
1.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
50.2000
,
-
51.0000
,
-
1.8000
,
1.0000
,
1.0000
,
1.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
49.8000
,
-
51.0000
,
-
1.8000
,
1.0000
,
1.0000
,
1.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
49.4000
,
-
51.0000
,
-
1.8000
,
0.5774
,
1.7321
,
1.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
49.0000
,
-
51.0000
,
-
1.8000
,
0.5774
,
1.7321
,
1.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
48.6000
,
-
51.0000
,
-
1.8000
,
0.8660
,
2.5981
,
1.0000
,
1.5700
,
0.0000
,
0.0000
]],
device
=
device
),
torch
.
tensor
([[
-
50.8000
,
-
50.8000
,
-
1.8000
,
1.7320
,
5.1962
,
2.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
50.8000
,
-
50.8000
,
-
1.8000
,
0.8000
,
0.8000
,
2.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
50.0000
,
-
50.8000
,
-
1.8000
,
0.8000
,
0.8000
,
2.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
49.2000
,
-
50.8000
,
-
1.8000
,
2.0000
,
2.0000
,
2.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
48.4000
,
-
50.8000
,
-
1.8000
,
2.0000
,
2.0000
,
2.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
47.6000
,
-
50.8000
,
-
1.8000
,
1.1548
,
3.4642
,
2.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
46.8000
,
-
50.8000
,
-
1.8000
,
1.1548
,
3.4642
,
2.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
46.0000
,
-
50.8000
,
-
1.8000
,
1.7320
,
5.1962
,
2.0000
,
1.5700
,
0.0000
,
0.0000
]],
device
=
device
),
torch
.
tensor
([[
-
50.4000
,
-
50.4000
,
-
1.8000
,
3.4640
,
10.3924
,
4.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
50.4000
,
-
50.4000
,
-
1.8000
,
1.6000
,
1.6000
,
4.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
48.8000
,
-
50.4000
,
-
1.8000
,
1.6000
,
1.6000
,
4.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
47.2000
,
-
50.4000
,
-
1.8000
,
4.0000
,
4.0000
,
4.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
45.6000
,
-
50.4000
,
-
1.8000
,
4.0000
,
4.0000
,
4.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
44.0000
,
-
50.4000
,
-
1.8000
,
2.3096
,
6.9284
,
4.0000
,
1.5700
,
0.0000
,
0.0000
],
[
-
42.4000
,
-
50.4000
,
-
1.8000
,
2.3096
,
6.9284
,
4.0000
,
0.0000
,
0.0000
,
0.0000
],
[
-
40.8000
,
-
50.4000
,
-
1.8000
,
3.4640
,
10.3924
,
4.0000
,
1.5700
,
0.0000
,
0.0000
]],
device
=
device
)
]
multi_level_anchors
=
anchor_generator
.
grid_anchors
(
featmap_sizes
,
device
=
device
)
expected_multi_level_shapes
=
[
torch
.
Size
([
524288
,
9
]),
torch
.
Size
([
131072
,
9
]),
torch
.
Size
([
32768
,
9
])
]
for
i
,
single_level_anchor
in
enumerate
(
multi_level_anchors
):
assert
single_level_anchor
.
shape
==
expected_multi_level_shapes
[
i
]
# set [:56:7] thus it could cover 8 (len(size) * len(rotations))
# anchors on 8 location
assert
single_level_anchor
[:
56
:
7
].
allclose
(
expected_grid_anchors
[
i
])
tests/test_config.py
View file @
4040dbda
...
...
@@ -70,6 +70,34 @@ def test_config_build_detector():
# _check_bbox_head(head_config, detector.bbox_head)
def
test_config_build_pipeline
():
"""
Test that all detection models defined in the configs can be initialized.
"""
from
mmcv
import
Config
from
mmdet3d.datasets.pipelines
import
Compose
config_dpath
=
_get_config_directory
()
print
(
'Found config_dpath = {!r}'
.
format
(
config_dpath
))
import
glob
config_fpaths
=
list
(
glob
.
glob
(
join
(
config_dpath
,
'**'
,
'*.py'
)))
config_fpaths
=
[
p
for
p
in
config_fpaths
if
p
.
find
(
'_base_'
)
==
-
1
]
config_names
=
[
relpath
(
p
,
config_dpath
)
for
p
in
config_fpaths
]
print
(
'Using {} config files'
.
format
(
len
(
config_names
)))
for
config_fname
in
config_names
:
config_fpath
=
join
(
config_dpath
,
config_fname
)
config_mod
=
Config
.
fromfile
(
config_fpath
)
# build train_pipeline
train_pipeline
=
Compose
(
config_mod
.
train_pipeline
)
test_pipeline
=
Compose
(
config_mod
.
test_pipeline
)
assert
train_pipeline
is
not
None
assert
test_pipeline
is
not
None
def
test_config_data_pipeline
():
"""
Test whether the data pipeline is valid and can process corner cases.
...
...
@@ -77,7 +105,7 @@ def test_config_data_pipeline():
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
from
mmcv
import
Config
from
mmdet.datasets.pipelines
import
Compose
from
mmdet
3d
.datasets.pipelines
import
Compose
import
numpy
as
np
config_dpath
=
_get_config_directory
()
...
...
tools/train.py
View file @
4040dbda
...
...
@@ -27,12 +27,18 @@ def parse_args():
'--validate'
,
action
=
'store_true'
,
help
=
'whether to evaluate the checkpoint during training'
)
parser
.
add_argument
(
group_gpus
=
parser
.
add_mutually_exclusive_group
()
group_gpus
.
add_argument
(
'--gpus'
,
type
=
int
,
default
=
1
,
help
=
'number of gpus to use '
'(only applicable to non-distributed training)'
)
group_gpus
.
add_argument
(
'--gpu-ids'
,
type
=
int
,
nargs
=
'+'
,
help
=
'ids of gpus to use '
'(only applicable to non-distributed training)'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
0
,
help
=
'random seed'
)
parser
.
add_argument
(
'--deterministic'
,
...
...
@@ -73,11 +79,14 @@ def main():
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
if
args
.
resume_from
is
not
None
:
cfg
.
resume_from
=
args
.
resume_from
cfg
.
gpus
=
args
.
gpus
if
args
.
gpu_ids
is
not
None
:
cfg
.
gpu_ids
=
args
.
gpu_ids
else
:
cfg
.
gpu_ids
=
range
(
1
)
if
args
.
gpus
is
None
else
range
(
args
.
gpus
)
if
args
.
autoscale_lr
:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg
.
optimizer
[
'lr'
]
=
cfg
.
optimizer
[
'lr'
]
*
cfg
.
gpu
s
/
8
cfg
.
optimizer
[
'lr'
]
=
cfg
.
optimizer
[
'lr'
]
*
len
(
cfg
.
gpu
_ids
)
/
8
# init distributed env first, since logger depends on the dist info.
if
args
.
launcher
==
'none'
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment