Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
2bb43004
Commit
2bb43004
authored
Apr 27, 2020
by
zhangwenwei
Browse files
Fix training bug
parent
4073acf7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
124 additions
and
23 deletions
+124
-23
mmdet3d/apis/train.py
mmdet3d/apis/train.py
+92
-0
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
+26
-17
mmdet3d/models/anchor_heads/train_mixins.py
mmdet3d/models/anchor_heads/train_mixins.py
+2
-2
mmdet3d/models/fusion_layers/point_fusion.py
mmdet3d/models/fusion_layers/point_fusion.py
+2
-3
tools/train.py
tools/train.py
+2
-1
No files found.
mmdet3d/apis/train.py
View file @
2bb43004
import
torch
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
DistSamplerSeedHook
,
Runner
from
mmdet.apis.train
import
parse_losses
from
mmdet.core
import
(
DistEvalHook
,
DistOptimizerHook
,
EvalHook
,
Fp16OptimizerHook
,
build_optimizer
)
from
mmdet.datasets
import
build_dataloader
,
build_dataset
from
mmdet.utils
import
get_root_logger
def
batch_processor
(
model
,
data
,
train_mode
):
...
...
@@ -27,3 +35,87 @@ def batch_processor(model, data, train_mode):
outputs
=
dict
(
loss
=
loss
,
log_vars
=
log_vars
,
num_samples
=
num_samples
)
return
outputs
def
train_detector
(
model
,
dataset
,
cfg
,
distributed
=
False
,
validate
=
False
,
timestamp
=
None
,
meta
=
None
):
logger
=
get_root_logger
(
cfg
.
log_level
)
# prepare data loaders
dataset
=
dataset
if
isinstance
(
dataset
,
(
list
,
tuple
))
else
[
dataset
]
data_loaders
=
[
build_dataloader
(
ds
,
cfg
.
data
.
samples_per_gpu
,
cfg
.
data
.
workers_per_gpu
,
# cfg.gpus will be ignored if distributed
len
(
cfg
.
gpu_ids
),
dist
=
distributed
,
seed
=
cfg
.
seed
)
for
ds
in
dataset
]
# put model on gpus
if
distributed
:
find_unused_parameters
=
cfg
.
get
(
'find_unused_parameters'
,
False
)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model
=
MMDistributedDataParallel
(
model
.
cuda
(),
device_ids
=
[
torch
.
cuda
.
current_device
()],
broadcast_buffers
=
False
,
find_unused_parameters
=
find_unused_parameters
)
else
:
model
=
MMDataParallel
(
model
.
cuda
(
cfg
.
gpu_ids
[
0
]),
device_ids
=
cfg
.
gpu_ids
)
# build runner
optimizer
=
build_optimizer
(
model
,
cfg
.
optimizer
)
runner
=
Runner
(
model
,
batch_processor
,
optimizer
,
cfg
.
work_dir
,
logger
=
logger
,
meta
=
meta
)
# an ugly walkaround to make the .log and .log.json filenames the same
runner
.
timestamp
=
timestamp
# fp16 setting
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
if
fp16_cfg
is
not
None
:
optimizer_config
=
Fp16OptimizerHook
(
**
cfg
.
optimizer_config
,
**
fp16_cfg
,
distributed
=
distributed
)
elif
distributed
:
optimizer_config
=
DistOptimizerHook
(
**
cfg
.
optimizer_config
)
else
:
optimizer_config
=
cfg
.
optimizer_config
# register hooks
runner
.
register_training_hooks
(
cfg
.
lr_config
,
optimizer_config
,
cfg
.
checkpoint_config
,
cfg
.
log_config
)
if
distributed
:
runner
.
register_hook
(
DistSamplerSeedHook
())
# register eval hooks
if
validate
:
val_dataset
=
build_dataset
(
cfg
.
data
.
val
,
dict
(
test_mode
=
True
))
val_dataloader
=
build_dataloader
(
val_dataset
,
samples_per_gpu
=
1
,
workers_per_gpu
=
cfg
.
data
.
workers_per_gpu
,
dist
=
distributed
,
shuffle
=
False
)
eval_cfg
=
cfg
.
get
(
'evaluation'
,
{})
eval_hook
=
DistEvalHook
if
distributed
else
EvalHook
runner
.
register_hook
(
eval_hook
(
val_dataloader
,
**
eval_cfg
))
if
cfg
.
resume_from
:
runner
.
resume
(
cfg
.
resume_from
)
elif
cfg
.
load_from
:
runner
.
load_checkpoint
(
cfg
.
load_from
)
runner
.
run
(
data_loaders
,
cfg
.
workflow
,
cfg
.
total_epochs
)
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
View file @
2bb43004
import
torch
from
mmdet3d.ops.iou3d
import
boxes_iou3d_gpu
from
mmdet.core.bbox
import
bbox_overlaps
from
mmdet.core.bbox.iou_calculators.builder
import
IOU_CALCULATORS
...
...
@@ -33,18 +35,21 @@ class BboxOverlaps3D(object):
def
bbox_overlaps_nearest_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
,
is_aligned
=
False
):
'''
:param bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]?
:param bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]?
:param mode: mode (str): "iou" (intersection over union) or iof
"""Calculate nearest 3D IoU
Args:
bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]?
bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]?
mode: mode (str): "iou" (intersection over union) or iof
(intersection over foreground).
:return: iou: (M, N) not support aligned mode currently
rbboxes: [N, 5(x, y, xdim, ydim, rad)] rotated bboxes
'''
rbboxes1_bev
=
bboxes1
.
index_select
(
dim
=-
1
,
index
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
]).
long
())
rbboxes2_bev
=
bboxes2
.
index_select
(
dim
=-
1
,
index
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
]).
long
())
Return:
iou: (M, N) not support aligned mode currently
"""
assert
bboxes1
.
size
(
-
1
)
==
bboxes2
.
size
(
-
1
)
==
7
column_index1
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
],
dtype
=
torch
.
long
)
rbboxes1_bev
=
bboxes1
.
index_select
(
dim
=-
1
,
index
=
column_index1
)
rbboxes2_bev
=
bboxes2
.
index_select
(
dim
=-
1
,
index
=
column_index1
)
# Change the bboxes to bev
# box conversion and iou calculation in torch version on CUDA
...
...
@@ -57,14 +62,18 @@ def bbox_overlaps_nearest_3d(bboxes1, bboxes2, mode='iou', is_aligned=False):
def
bbox_overlaps_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
):
'''
"""Calculate 3D IoU using cuda implementation
:param bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]
:param bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]
:param mode: mode (str): "iou" (intersection over union) or
Args:
bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]
bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]
mode: mode (str): "iou" (intersection over union) or
iof (intersection over foreground).
:return: iou: (M, N) not support aligned mode currently
'''
Return:
iou: (M, N) not support aligned mode currently
"""
# TODO: check the input dimension meanings,
# this is inconsistent with that in bbox_overlaps_nearest_3d
assert
bboxes1
.
size
(
-
1
)
==
bboxes2
.
size
(
-
1
)
==
7
return
boxes_iou3d_gpu
(
bboxes1
,
bboxes2
,
mode
)
mmdet3d/models/anchor_heads/train_mixins.py
View file @
2bb43004
...
...
@@ -176,10 +176,10 @@ class AnchorTrainMixin(object):
neg_inds
=
sampling_result
.
neg_inds
else
:
pos_inds
=
torch
.
nonzero
(
anchors
.
new_zeros
((
anchors
.
shape
[
0
],
),
dtype
=
torch
.
long
)
>
0
anchors
.
new_zeros
((
anchors
.
shape
[
0
],
),
dtype
=
torch
.
bool
)
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
anchors
.
new_zeros
((
anchors
.
shape
[
0
],
),
dtype
=
torch
.
long
)
==
anchors
.
new_zeros
((
anchors
.
shape
[
0
],
),
dtype
=
torch
.
bool
)
==
0
).
squeeze
(
-
1
).
unique
()
if
gt_labels
is
not
None
:
...
...
mmdet3d/models/fusion_layers/point_fusion.py
View file @
2bb43004
...
...
@@ -235,9 +235,8 @@ class PointFusion(nn.Module):
pts
.
new_tensor
(
img_meta
[
'pcd_trans'
])
if
'pcd_trans'
in
img_meta
.
keys
()
else
0
)
pcd_rotate_mat
=
(
pts
.
new_tensor
(
img_meta
[
'pcd_rotation'
])
if
'pcd_rotation'
in
img_meta
.
keys
()
else
torch
.
eye
(
3
).
type_as
(
pts
).
to
(
pts
.
device
))
pts
.
new_tensor
(
img_meta
[
'pcd_rotation'
])
if
'pcd_rotation'
in
img_meta
.
keys
()
else
torch
.
eye
(
3
).
type_as
(
pts
).
to
(
pts
.
device
))
img_scale_factor
=
(
img_meta
[
'scale_factor'
]
if
'scale_factor'
in
img_meta
.
keys
()
else
1
)
...
...
tools/train.py
View file @
2bb43004
...
...
@@ -11,10 +11,11 @@ from mmcv import Config
from
mmcv.runner
import
init_dist
from
mmdet3d
import
__version__
from
mmdet3d.apis
import
train_detector
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.models
import
build_detector
from
mmdet3d.utils
import
collect_env
from
mmdet.apis
import
get_root_logger
,
set_random_seed
,
train_detector
from
mmdet.apis
import
get_root_logger
,
set_random_seed
def
parse_args
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment