Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
a34823dc
Commit
a34823dc
authored
May 17, 2022
by
VVsssssk
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor]fix train.py and test.py
parent
5c5e459b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
73 additions
and
421 deletions
+73
-421
mmdet3d/utils/__init__.py
mmdet3d/utils/__init__.py
+2
-4
mmdet3d/utils/setup_env.py
mmdet3d/utils/setup_env.py
+20
-0
tests/test_utils/test_setup_env.py
tests/test_utils/test_setup_env.py
+15
-1
tools/test.py
tools/test.py
+26
-211
tools/train.py
tools/train.py
+10
-205
No files found.
mmdet3d/utils/__init__.py
View file @
a34823dc
...
...
@@ -4,11 +4,9 @@ from mmcv.utils import Registry, build_from_cfg, print_log
from
.collect_env
import
collect_env
from
.compat_cfg
import
compat_cfg
from
.logger
import
get_root_logger
from
.misc
import
find_latest_checkpoint
from
.setup_env
import
setup_multi_processes
from
.setup_env
import
register_all_modules
,
setup_multi_processes
__all__
=
[
'Registry'
,
'build_from_cfg'
,
'get_root_logger'
,
'collect_env'
,
'print_log'
,
'setup_multi_processes'
,
'find_latest_checkpoint'
,
'compat_cfg'
'print_log'
,
'setup_multi_processes'
,
'compat_cfg'
,
'register_all_modules'
]
mmdet3d/utils/setup_env.py
View file @
a34823dc
...
...
@@ -4,6 +4,7 @@ import platform
import
warnings
import
cv2
from
mmengine
import
DefaultScope
from
torch
import
multiprocessing
as
mp
...
...
@@ -51,3 +52,22 @@ def setup_multi_processes(cfg):
f
'overloaded, please further tune the variable for optimal '
f
'performance in your application as needed.'
)
os
.
environ
[
'MKL_NUM_THREADS'
]
=
str
(
mkl_num_threads
)
def
register_all_modules
(
init_default_scope
:
bool
=
True
)
->
None
:
"""Register all modules in mmdet3d into the registries.
Args:
init_default_scope (bool): Whether initialize the mmdet3d default scope.
When `init_default_scope=True`, the global default scope will be
set to `mmdet3d`, and all registries will build modules from mmdet3d's
registry node. To understand more about the registry, please refer
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
"""
# noqa
import
mmdet3d.core
# noqa: F401,F403
import
mmdet3d.datasets
# noqa: F401,F403
import
mmdet3d.models
# noqa: F401,F403
import
mmdet3d.ops
# noqa: F401,F403
if
init_default_scope
:
DefaultScope
.
get_instance
(
'mmdet3d'
,
scope_name
=
'mmdet3d'
)
tests/test_utils/test_setup_env.py
View file @
a34823dc
...
...
@@ -2,11 +2,25 @@
import
multiprocessing
as
mp
import
os
import
platform
import
sys
import
cv2
from
mmcv
import
Config
from
mmengine
import
DefaultScope
from
mmdet3d.utils
import
setup_multi_processes
from
mmdet3d.utils
import
register_all_modules
,
setup_multi_processes
def
test_register_all_modules
():
from
mmdet3d.registry
import
TRANSFORMS
sys
.
modules
.
pop
(
'mmdet3d.datasets'
,
None
)
sys
.
modules
.
pop
(
'mmdet3d.datasets.pipelines'
,
None
)
TRANSFORMS
.
_module_dict
.
pop
(
'PointSample'
,
None
)
assert
'PointSample'
not
in
TRANSFORMS
.
module_dict
register_all_modules
(
init_default_scope
=
True
)
assert
'PointSample'
in
TRANSFORMS
.
module_dict
assert
DefaultScope
.
get_current_instance
().
scope_name
==
'mmdet3d'
def
test_setup_multi_processes
():
...
...
tools/test.py
View file @
a34823dc
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
warnings
import
os.path
as
osp
import
mmcv
import
torch
from
mmcv
import
Config
,
DictAction
from
mmcv.cnn
import
fuse_conv_bn
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
get_dist_info
,
init_dist
,
load_checkpoint
,
wrap_fp16_model
)
from
mmengine.config
import
Config
,
DictAction
from
mmengine.runner
import
Runner
import
mmdet
from
mmdet3d.apis
import
single_gpu_test
from
mmdet3d.datasets
import
build_dataloader
,
build_dataset
from
mmdet3d.models
import
build_model
from
mmdet.apis
import
multi_gpu_test
,
set_random_seed
from
mmdet.datasets
import
replace_ImageToTensor
if
mmdet
.
__version__
>
'2.23.0'
:
# If mmdet version > 2.23.0, setup_multi_processes would be imported and
# used from mmdet instead of mmdet3d.
from
mmdet.utils
import
setup_multi_processes
else
:
from
mmdet3d.utils
import
setup_multi_processes
try
:
# If mmdet version > 2.23.0, compat_cfg would be imported and
# used from mmdet instead of mmdet3d.
from
mmdet.utils
import
compat_cfg
except
ImportError
:
from
mmdet3d.utils
import
compat_cfg
from
mmdet3d.utils
import
register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'MMDet test (and eval) a model'
)
description
=
'MMDet
3D
test (and eval) a model'
)
parser
.
add_argument
(
'config'
,
help
=
'test config file path'
)
parser
.
add_argument
(
'checkpoint'
,
help
=
'checkpoint file'
)
parser
.
add_argument
(
'--out'
,
help
=
'output result file in pickle format'
)
parser
.
add_argument
(
'--fuse-conv-bn'
,
action
=
'store_true'
,
help
=
'Whether to fuse conv and bn, this will slightly increase'
'the inference speed'
)
parser
.
add_argument
(
'--gpu-ids'
,
type
=
int
,
nargs
=
'+'
,
help
=
'(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)'
)
parser
.
add_argument
(
'--gpu-id'
,
type
=
int
,
default
=
0
,
help
=
'id of gpu to use '
'(only applicable to non-distributed testing)'
)
parser
.
add_argument
(
'--format-only'
,
action
=
'store_true'
,
help
=
'Format the output results without perform evaluation. It is'
'useful when you want to format the result to a specific format and '
'submit it to the test server'
)
parser
.
add_argument
(
'--eval'
,
type
=
str
,
nargs
=
'+'
,
help
=
'evaluation metrics, which depends on the dataset, e.g., "bbox",'
' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
parser
.
add_argument
(
'--show-dir'
,
help
=
'directory where results will be saved'
)
parser
.
add_argument
(
'--gpu-collect'
,
action
=
'store_true'
,
help
=
'whether to use gpu to collect results.'
)
parser
.
add_argument
(
'--tmpdir'
,
help
=
'tmp directory used for collecting results from multiple '
'workers, available when gpu-collect is not specified'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
0
,
help
=
'random seed'
)
parser
.
add_argument
(
'--deterministic'
,
action
=
'store_true'
,
help
=
'whether to set deterministic options for CUDNN backend.'
)
'--work-dir'
,
help
=
'the directory to save the file containing evaluation metrics'
)
parser
.
add_argument
(
'--cfg-options'
,
nargs
=
'+'
,
...
...
@@ -94,19 +28,6 @@ def parse_args():
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.'
)
parser
.
add_argument
(
'--options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function (deprecate), '
'change to --eval-options instead.'
)
parser
.
add_argument
(
'--eval-options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function'
)
parser
.
add_argument
(
'--launcher'
,
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
...
...
@@ -116,144 +37,38 @@ def parse_args():
args
=
parser
.
parse_args
()
if
'LOCAL_RANK'
not
in
os
.
environ
:
os
.
environ
[
'LOCAL_RANK'
]
=
str
(
args
.
local_rank
)
if
args
.
options
and
args
.
eval_options
:
raise
ValueError
(
'--options and --eval-options cannot be both specified, '
'--options is deprecated in favor of --eval-options'
)
if
args
.
options
:
warnings
.
warn
(
'--options is deprecated in favor of --eval-options'
)
args
.
eval_options
=
args
.
options
return
args
def
main
():
args
=
parse_args
()
assert
args
.
out
or
args
.
eval
or
args
.
format_only
or
args
.
show
\
or
args
.
show_dir
,
\
(
'Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir"'
)
if
args
.
eval
and
args
.
format_only
:
raise
ValueError
(
'--eval and --format_only cannot be both specified'
)
if
args
.
out
is
not
None
and
not
args
.
out
.
endswith
((
'.pkl'
,
'.pickle'
)):
raise
ValueError
(
'The output file must be a pkl file.'
)
# register all modules in mmdet3d into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules
(
init_default_scope
=
False
)
# load config
cfg
=
Config
.
fromfile
(
args
.
config
)
cfg
.
launcher
=
args
.
launcher
if
args
.
cfg_options
is
not
None
:
cfg
.
merge_from_dict
(
args
.
cfg_options
)
cfg
=
compat_cfg
(
cfg
)
# set multi-process settings
setup_multi_processes
(
cfg
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
cfg
.
model
.
pretrained
=
None
if
args
.
gpu_ids
is
not
None
:
cfg
.
gpu_ids
=
args
.
gpu_ids
[
0
:
1
]
warnings
.
warn
(
'`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed testing. Use the first GPU '
'in `gpu_ids` now.'
)
else
:
cfg
.
gpu_ids
=
[
args
.
gpu_id
]
# init distributed env first, since logger depends on the dist info.
if
args
.
launcher
==
'none'
:
distributed
=
False
else
:
distributed
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
test_dataloader_default_args
=
dict
(
samples_per_gpu
=
1
,
workers_per_gpu
=
2
,
dist
=
distributed
,
shuffle
=
False
)
# in case the test dataset is concatenated
if
isinstance
(
cfg
.
data
.
test
,
dict
):
cfg
.
data
.
test
.
test_mode
=
True
if
cfg
.
data
.
test_dataloader
.
get
(
'samples_per_gpu'
,
1
)
>
1
:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg
.
data
.
test
.
pipeline
=
replace_ImageToTensor
(
cfg
.
data
.
test
.
pipeline
)
elif
isinstance
(
cfg
.
data
.
test
,
list
):
for
ds_cfg
in
cfg
.
data
.
test
:
ds_cfg
.
test_mode
=
True
if
cfg
.
data
.
test_dataloader
.
get
(
'samples_per_gpu'
,
1
)
>
1
:
for
ds_cfg
in
cfg
.
data
.
test
:
ds_cfg
.
pipeline
=
replace_ImageToTensor
(
ds_cfg
.
pipeline
)
test_loader_cfg
=
{
**
test_dataloader_default_args
,
**
cfg
.
data
.
get
(
'test_dataloader'
,
{})
}
# set random seeds
if
args
.
seed
is
not
None
:
set_random_seed
(
args
.
seed
,
deterministic
=
args
.
deterministic
)
# build the dataloader
dataset
=
build_dataset
(
cfg
.
data
.
test
)
data_loader
=
build_dataloader
(
dataset
,
**
test_loader_cfg
)
# work_dir is determined in this priority: CLI > segment in file > filename
if
args
.
work_dir
is
not
None
:
# update configs according to CLI args if args.work_dir is not None
cfg
.
work_dir
=
args
.
work_dir
elif
cfg
.
get
(
'work_dir'
,
None
)
is
None
:
# use config filename as default work_dir if cfg.work_dir is None
cfg
.
work_dir
=
osp
.
join
(
'./work_dirs'
,
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
# build the model and load checkpoint
cfg
.
model
.
train_cfg
=
None
model
=
build_model
(
cfg
.
model
,
test_cfg
=
cfg
.
get
(
'test_cfg'
))
fp16_cfg
=
cfg
.
get
(
'fp16'
,
None
)
if
fp16_cfg
is
not
None
:
wrap_fp16_model
(
model
)
checkpoint
=
load_checkpoint
(
model
,
args
.
checkpoint
,
map_location
=
'cpu'
)
if
args
.
fuse_conv_bn
:
model
=
fuse_conv_bn
(
model
)
# old versions did not save class info in checkpoints, this walkaround is
# for backward compatibility
if
'CLASSES'
in
checkpoint
.
get
(
'meta'
,
{}):
model
.
CLASSES
=
checkpoint
[
'meta'
][
'CLASSES'
]
else
:
model
.
CLASSES
=
dataset
.
CLASSES
# palette for visualization in segmentation tasks
if
'PALETTE'
in
checkpoint
.
get
(
'meta'
,
{}):
model
.
PALETTE
=
checkpoint
[
'meta'
][
'PALETTE'
]
elif
hasattr
(
dataset
,
'PALETTE'
):
# segmentation dataset has `PALETTE` attribute
model
.
PALETTE
=
dataset
.
PALETTE
cfg
.
load_from
=
args
.
checkpoint
if
not
distributed
:
model
=
MMDataParallel
(
model
,
device_ids
=
cfg
.
gpu_ids
)
outputs
=
single_gpu_test
(
model
,
data_loader
,
args
.
show
,
args
.
show_dir
)
else
:
model
=
MMDistributedDataParallel
(
model
.
cuda
(),
device_ids
=
[
torch
.
cuda
.
current_device
()],
broadcast_buffers
=
False
)
outputs
=
multi_gpu_test
(
model
,
data_loader
,
args
.
tmpdir
,
args
.
gpu_collect
)
# build the runner from config
runner
=
Runner
.
from_cfg
(
cfg
)
rank
,
_
=
get_dist_info
()
if
rank
==
0
:
if
args
.
out
:
print
(
f
'
\n
writing results to
{
args
.
out
}
'
)
mmcv
.
dump
(
outputs
,
args
.
out
)
kwargs
=
{}
if
args
.
eval_options
is
None
else
args
.
eval_options
if
args
.
format_only
:
dataset
.
format_results
(
outputs
,
**
kwargs
)
if
args
.
eval
:
eval_kwargs
=
cfg
.
get
(
'evaluation'
,
{}).
copy
()
# hard-code way to remove EvalHook args
for
key
in
[
'interval'
,
'tmpdir'
,
'start'
,
'gpu_collect'
,
'save_best'
,
'rule'
]:
eval_kwargs
.
pop
(
key
,
None
)
eval_kwargs
.
update
(
dict
(
metric
=
args
.
eval
,
**
kwargs
))
print
(
dataset
.
evaluate
(
outputs
,
**
eval_kwargs
))
# start testing
runner
.
test
()
if
__name__
==
'__main__'
:
...
...
tools/train.py
View file @
a34823dc
# Copyright (c) OpenMMLab. All rights reserved.
from
__future__
import
division
import
argparse
import
copy
import
os
import
time
import
warnings
from
os
import
path
as
osp
import
mmcv
import
torch
import
torch.distributed
as
dist
from
mmcv
import
Config
,
DictAction
from
mm
cv.run
ne
r
import
get_dist_info
,
init_dist
from
mm
engi
ne
import
Runner
from
mmdet
import
__version__
as
mmdet_version
from
mmdet3d
import
__version__
as
mmdet3d_version
from
mmdet3d.apis
import
init_random_seed
,
train_model
from
mmdet3d.datasets
import
build_dataset
from
mmdet3d.models
import
build_model
from
mmdet3d.utils
import
collect_env
,
get_root_logger
from
mmdet.apis
import
set_random_seed
from
mmseg
import
__version__
as
mmseg_version
try
:
# If mmdet version > 2.20.0, setup_multi_processes would be imported and
# used from mmdet instead of mmdet3d.
from
mmdet.utils
import
setup_multi_processes
except
ImportError
:
from
mmdet3d.utils
import
setup_multi_processes
from
mmdet3d.utils
import
register_all_modules
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
parser
.
add_argument
(
'--work-dir'
,
help
=
'the dir to save logs and models'
)
parser
.
add_argument
(
'--resume-from'
,
help
=
'the checkpoint file to resume from'
)
parser
.
add_argument
(
'--auto-resume'
,
action
=
'store_true'
,
help
=
'resume from the latest checkpoint automatically'
)
parser
.
add_argument
(
'--no-validate'
,
action
=
'store_true'
,
help
=
'whether not to evaluate the checkpoint during training'
)
group_gpus
=
parser
.
add_mutually_exclusive_group
()
group_gpus
.
add_argument
(
'--gpus'
,
type
=
int
,
help
=
'(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)'
)
group_gpus
.
add_argument
(
'--gpu-ids'
,
type
=
int
,
nargs
=
'+'
,
help
=
'(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)'
)
group_gpus
.
add_argument
(
'--gpu-id'
,
type
=
int
,
default
=
0
,
help
=
'number of gpus to use '
'(only applicable to non-distributed training)'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
0
,
help
=
'random seed'
)
parser
.
add_argument
(
'--diff-seed'
,
action
=
'store_true'
,
help
=
'Whether or not set different seeds for different ranks'
)
parser
.
add_argument
(
'--deterministic'
,
action
=
'store_true'
,
help
=
'whether to set deterministic options for CUDNN backend.'
)
parser
.
add_argument
(
'--options'
,
nargs
=
'+'
,
action
=
DictAction
,
help
=
'override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file (deprecate), '
'change to --cfg-options instead.'
)
parser
.
add_argument
(
'--cfg-options'
,
nargs
=
'+'
,
...
...
@@ -94,39 +30,24 @@ def parse_args():
default
=
'none'
,
help
=
'job launcher'
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
'--autoscale-lr'
,
action
=
'store_true'
,
help
=
'automatically scale lr with the number of gpus'
)
args
=
parser
.
parse_args
()
if
'LOCAL_RANK'
not
in
os
.
environ
:
os
.
environ
[
'LOCAL_RANK'
]
=
str
(
args
.
local_rank
)
if
args
.
options
and
args
.
cfg_options
:
raise
ValueError
(
'--options and --cfg-options cannot be both specified, '
'--options is deprecated in favor of --cfg-options'
)
if
args
.
options
:
warnings
.
warn
(
'--options is deprecated in favor of --cfg-options'
)
args
.
cfg_options
=
args
.
options
return
args
def
main
():
args
=
parse_args
()
# register all modules in mmdet3d into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules
(
init_default_scope
=
False
)
cfg
=
Config
.
fromfile
(
args
.
config
)
if
args
.
cfg_options
is
not
None
:
cfg
.
merge_from_dict
(
args
.
cfg_options
)
# set multi-process settings
setup_multi_processes
(
cfg
)
# set cudnn_benchmark
if
cfg
.
get
(
'cudnn_benchmark'
,
False
):
torch
.
backends
.
cudnn
.
benchmark
=
True
# work_dir is determined in this priority: CLI > segment in file > filename
if
args
.
work_dir
is
not
None
:
# update configs according to CLI args if args.work_dir is not None
...
...
@@ -135,128 +56,12 @@ def main():
# use config filename as default work_dir if cfg.work_dir is None
cfg
.
work_dir
=
osp
.
join
(
'./work_dirs'
,
osp
.
splitext
(
osp
.
basename
(
args
.
config
))[
0
])
if
args
.
resume_from
is
not
None
:
cfg
.
resume_from
=
args
.
resume_from
if
args
.
auto_resume
:
cfg
.
auto_resume
=
args
.
auto_resume
warnings
.
warn
(
'`--auto-resume` is only supported when mmdet'
'version >= 2.20.0 for 3D detection model or'
'mmsegmentation verision >= 0.21.0 for 3D'
'segmentation model'
)
if
args
.
gpus
is
not
None
:
cfg
.
gpu_ids
=
range
(
1
)
warnings
.
warn
(
'`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.'
)
if
args
.
gpu_ids
is
not
None
:
cfg
.
gpu_ids
=
args
.
gpu_ids
[
0
:
1
]
warnings
.
warn
(
'`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.'
)
if
args
.
gpus
is
None
and
args
.
gpu_ids
is
None
:
cfg
.
gpu_ids
=
[
args
.
gpu_id
]
if
args
.
autoscale_lr
:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg
.
optimizer
[
'lr'
]
=
cfg
.
optimizer
[
'lr'
]
*
len
(
cfg
.
gpu_ids
)
/
8
# init distributed env first, since logger depends on the dist info.
if
args
.
launcher
==
'none'
:
distributed
=
False
else
:
distributed
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
# re-set gpu_ids with distributed training mode
_
,
world_size
=
get_dist_info
()
cfg
.
gpu_ids
=
range
(
world_size
)
# create work_dir
mmcv
.
mkdir_or_exist
(
osp
.
abspath
(
cfg
.
work_dir
))
# dump config
cfg
.
dump
(
osp
.
join
(
cfg
.
work_dir
,
osp
.
basename
(
args
.
config
)))
# init the logger before other steps
timestamp
=
time
.
strftime
(
'%Y%m%d_%H%M%S'
,
time
.
localtime
())
log_file
=
osp
.
join
(
cfg
.
work_dir
,
f
'
{
timestamp
}
.log'
)
# specify logger name, if we still use 'mmdet', the output info will be
# filtered and won't be saved in the log_file
# TODO: ugly workaround to judge whether we are training det or seg model
if
cfg
.
model
.
type
in
[
'EncoderDecoder3D'
]:
logger_name
=
'mmseg'
else
:
logger_name
=
'mmdet'
logger
=
get_root_logger
(
log_file
=
log_file
,
log_level
=
cfg
.
log_level
,
name
=
logger_name
)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta
=
dict
()
# log env info
env_info_dict
=
collect_env
()
env_info
=
'
\n
'
.
join
([(
f
'
{
k
}
:
{
v
}
'
)
for
k
,
v
in
env_info_dict
.
items
()])
dash_line
=
'-'
*
60
+
'
\n
'
logger
.
info
(
'Environment info:
\n
'
+
dash_line
+
env_info
+
'
\n
'
+
dash_line
)
meta
[
'env_info'
]
=
env_info
meta
[
'config'
]
=
cfg
.
pretty_text
# log some basic info
logger
.
info
(
f
'Distributed training:
{
distributed
}
'
)
logger
.
info
(
f
'Config:
\n
{
cfg
.
pretty_text
}
'
)
# set random seeds
seed
=
init_random_seed
(
args
.
seed
)
seed
=
seed
+
dist
.
get_rank
()
if
args
.
diff_seed
else
seed
logger
.
info
(
f
'Set random seed to
{
seed
}
, '
f
'deterministic:
{
args
.
deterministic
}
'
)
set_random_seed
(
seed
,
deterministic
=
args
.
deterministic
)
cfg
.
seed
=
seed
meta
[
'seed'
]
=
seed
meta
[
'exp_name'
]
=
osp
.
basename
(
args
.
config
)
model
=
build_model
(
cfg
.
model
,
train_cfg
=
cfg
.
get
(
'train_cfg'
),
test_cfg
=
cfg
.
get
(
'test_cfg'
))
model
.
init_weights
()
# build the runner from config
runner
=
Runner
.
from_cfg
(
cfg
)
logger
.
info
(
f
'Model:
\n
{
model
}
'
)
datasets
=
[
build_dataset
(
cfg
.
data
.
train
)]
if
len
(
cfg
.
workflow
)
==
2
:
val_dataset
=
copy
.
deepcopy
(
cfg
.
data
.
val
)
# in case we use a dataset wrapper
if
'dataset'
in
cfg
.
data
.
train
:
val_dataset
.
pipeline
=
cfg
.
data
.
train
.
dataset
.
pipeline
else
:
val_dataset
.
pipeline
=
cfg
.
data
.
train
.
pipeline
# set test_mode=False here in deep copied config
# which do not affect AP/AR calculation later
# refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa
val_dataset
.
test_mode
=
False
datasets
.
append
(
build_dataset
(
val_dataset
))
if
cfg
.
checkpoint_config
is
not
None
:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
mmdet_version
,
mmseg_version
=
mmseg_version
,
mmdet3d_version
=
mmdet3d_version
,
config
=
cfg
.
pretty_text
,
CLASSES
=
datasets
[
0
].
CLASSES
,
PALETTE
=
datasets
[
0
].
PALETTE
# for segmentors
if
hasattr
(
datasets
[
0
],
'PALETTE'
)
else
None
)
# add an attribute for visualization convenience
model
.
CLASSES
=
datasets
[
0
].
CLASSES
train_model
(
model
,
datasets
,
cfg
,
distributed
=
distributed
,
validate
=
(
not
args
.
no_validate
),
timestamp
=
timestamp
,
meta
=
meta
)
# start training
runner
.
train
()
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment