Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
bce7d0c3
Commit
bce7d0c3
authored
May 12, 2020
by
yinchimaoliang
Browse files
Merge branch 'master_temp' into indoor_pipeline
parents
1756485e
868c5fab
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
13 deletions
+17
-13
tools/train.py
tools/train.py
+17
-13
No files found.
tools/train.py
View file @
bce7d0c3
...
...
@@ -8,7 +8,7 @@ import time
import
mmcv
import
torch
from
mmcv
import
Config
from
mmcv
import
Config
,
DictAction
from
mmcv.runner
import
init_dist
from
mmdet3d
import
__version__
...
...
@@ -26,9 +26,9 @@ def parse_args():
parser
.
add_argument
(
'--resume-from'
,
help
=
'the checkpoint file to resume from'
)
parser
.
add_argument
(
'--validate'
,
'--
no-
validate'
,
action
=
'store_true'
,
help
=
'whether to evaluate the checkpoint during training'
)
help
=
'whether
not
to evaluate the checkpoint during training'
)
group_gpus
=
parser
.
add_mutually_exclusive_group
()
group_gpus
.
add_argument
(
'--gpus'
,
...
...
@@ -46,6 +46,8 @@ def parse_args():
'--deterministic'
,
action
=
'store_true'
,
help
=
'whether to set deterministic options for CUDNN backend.'
)
parser
.
add_argument
(
'--options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'arguments in dict'
)
parser
.
add_argument
(
'--launcher'
,
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
...
...
@@ -67,6 +69,9 @@ def main():
args
=
parse_args
()
cfg
=
Config
.
fromfile
(
args
.
config
)
if
args
.
options
is
not
None
:
cfg
.
merge_from_dict
(
args
.
options
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
...
...
@@ -101,7 +106,7 @@ def main():
mmcv
.
mkdir_or_exist
(
osp
.
abspath
(
cfg
.
work_dir
))
# init the logger before other steps
timestamp
=
time
.
strftime
(
'%Y%m%d_%H%M%S'
,
time
.
localtime
())
log_file
=
osp
.
join
(
cfg
.
work_dir
,
'{
}.log'
.
format
(
timestamp
)
)
log_file
=
osp
.
join
(
cfg
.
work_dir
,
f
'
{
timestamp
}
.log'
)
logger
=
get_root_logger
(
log_file
=
log_file
,
log_level
=
cfg
.
log_level
)
# add a logging filter
...
...
@@ -113,28 +118,27 @@ def main():
meta
=
dict
()
# log env info
env_info_dict
=
collect_env
()
env_info
=
'
\n
'
.
join
([(
'{}: {}'
.
format
(
k
,
v
))
for
k
,
v
in
env_info_dict
.
items
()])
env_info
=
'
\n
'
.
join
([(
f
'
{
k
}
:
{
v
}
'
)
for
k
,
v
in
env_info_dict
.
items
()])
dash_line
=
'-'
*
60
+
'
\n
'
logger
.
info
(
'Environment info:
\n
'
+
dash_line
+
env_info
+
'
\n
'
+
dash_line
)
meta
[
'env_info'
]
=
env_info
# log some basic info
logger
.
info
(
'Distributed training: {
}'
.
format
(
distributed
)
)
logger
.
info
(
'Config:
\n
{
}'
.
format
(
cfg
.
text
)
)
logger
.
info
(
f
'Distributed training:
{
distributed
}
'
)
logger
.
info
(
f
'Config:
\n
{
cfg
.
pretty_
text
}
'
)
# set random seeds
if
args
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {
}, deterministic: {}'
.
format
(
args
.
seed
,
args
.
deterministic
)
)
logger
.
info
(
f
'Set random seed to
{
args
.
seed
}
, '
f
'deterministic:
{
args
.
deterministic
}
'
)
set_random_seed
(
args
.
seed
,
deterministic
=
args
.
deterministic
)
cfg
.
seed
=
args
.
seed
meta
[
'seed'
]
=
args
.
seed
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
logger
.
info
(
'Model:
\n
{
}'
.
format
(
model
)
)
logger
.
info
(
f
'Model:
\n
{
model
}
'
)
datasets
=
[
build_dataset
(
cfg
.
data
.
train
)]
if
len
(
cfg
.
workflow
)
==
2
:
val_dataset
=
copy
.
deepcopy
(
cfg
.
data
.
val
)
...
...
@@ -145,7 +149,7 @@ def main():
# checkpoints as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
,
config
=
cfg
.
pretty_
text
,
CLASSES
=
datasets
[
0
].
CLASSES
)
# add an attribute for visualization convenience
model
.
CLASSES
=
datasets
[
0
].
CLASSES
...
...
@@ -154,7 +158,7 @@ def main():
datasets
,
cfg
,
distributed
=
distributed
,
validate
=
args
.
validate
,
validate
=
(
not
args
.
no_
validate
)
,
timestamp
=
timestamp
,
meta
=
meta
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment