Commit a34823dc authored by VVsssssk, committed by ChaimZhu

[Refactor] Fix train.py and test.py

parent 5c5e459b
@@ -4,11 +4,9 @@ from mmcv.utils import Registry, build_from_cfg, print_log
 from .collect_env import collect_env
 from .compat_cfg import compat_cfg
 from .logger import get_root_logger
-from .misc import find_latest_checkpoint
-from .setup_env import setup_multi_processes
+from .setup_env import register_all_modules, setup_multi_processes
 
 __all__ = [
     'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
-    'print_log', 'setup_multi_processes', 'find_latest_checkpoint',
-    'compat_cfg'
+    'print_log', 'setup_multi_processes', 'compat_cfg', 'register_all_modules'
 ]
@@ -4,6 +4,7 @@ import platform
 import warnings
 
 import cv2
+from mmengine import DefaultScope
 from torch import multiprocessing as mp
@@ -51,3 +52,22 @@ def setup_multi_processes(cfg):
             f'overloaded, please further tune the variable for optimal '
             f'performance in your application as needed.')
         os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmdet3d into the registries.
+
+    Args:
+        init_default_scope (bool): Whether initialize the mmdet3d default scope.
+            When `init_default_scope=True`, the global default scope will be
+            set to `mmdet3d`, and all registries will build modules from mmdet3d's
+            registry node. To understand more about the registry, please refer
+            to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmdet3d.core  # noqa: F401,F403
+    import mmdet3d.datasets  # noqa: F401,F403
+    import mmdet3d.models  # noqa: F401,F403
+    import mmdet3d.ops  # noqa: F401,F403
+
+    if init_default_scope:
+        DefaultScope.get_instance('mmdet3d', scope_name='mmdet3d')
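
For orientation, the helper added above is meant to be called once, before anything is built from the mmdet3d registries. Below is a minimal usage sketch, assuming the TRANSFORMS registry from mmdet3d.registry and the 'PointSample' transform that the unit test further down also relies on; the num_points value is just an illustrative number.

# Hedged sketch, not part of this commit: register everything once, then
# build a registered component from a plain dict config by its type name.
from mmdet3d.registry import TRANSFORMS
from mmdet3d.utils import register_all_modules

register_all_modules(init_default_scope=True)  # also sets the 'mmdet3d' default scope

# 'PointSample' is only available because register_all_modules imports
# mmdet3d.datasets, whose transforms register themselves into TRANSFORMS.
point_sample = TRANSFORMS.build(dict(type='PointSample', num_points=1024))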
@@ -2,11 +2,25 @@
 import multiprocessing as mp
 import os
 import platform
+import sys
 
 import cv2
 from mmcv import Config
+from mmengine import DefaultScope
 
-from mmdet3d.utils import setup_multi_processes
+from mmdet3d.utils import register_all_modules, setup_multi_processes
+
+
+def test_register_all_modules():
+    from mmdet3d.registry import TRANSFORMS
+
+    sys.modules.pop('mmdet3d.datasets', None)
+    sys.modules.pop('mmdet3d.datasets.pipelines', None)
+    TRANSFORMS._module_dict.pop('PointSample', None)
+    assert 'PointSample' not in TRANSFORMS.module_dict
+    register_all_modules(init_default_scope=True)
+    assert 'PointSample' in TRANSFORMS.module_dict
+    assert DefaultScope.get_current_instance().scope_name == 'mmdet3d'
 
 
 def test_setup_multi_processes():
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import os
-import warnings
+import os.path as osp
 
-import mmcv
-import torch
-from mmcv import Config, DictAction
-from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
-                         wrap_fp16_model)
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-import mmdet
-from mmdet3d.apis import single_gpu_test
-from mmdet3d.datasets import build_dataloader, build_dataset
-from mmdet3d.models import build_model
-from mmdet.apis import multi_gpu_test, set_random_seed
-from mmdet.datasets import replace_ImageToTensor
-
-if mmdet.__version__ > '2.23.0':
-    # If mmdet version > 2.23.0, setup_multi_processes would be imported and
-    # used from mmdet instead of mmdet3d.
-    from mmdet.utils import setup_multi_processes
-else:
-    from mmdet3d.utils import setup_multi_processes
-
-try:
-    # If mmdet version > 2.23.0, compat_cfg would be imported and
-    # used from mmdet instead of mmdet3d.
-    from mmdet.utils import compat_cfg
-except ImportError:
-    from mmdet3d.utils import compat_cfg
+from mmdet3d.utils import register_all_modules
+
+# TODO: support fuse_conv_bn, visualization, and format_only
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description='MMDet test (and eval) a model')
+        description='MMDet3D test (and eval) a model')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('--out', help='output result file in pickle format')
-    parser.add_argument(
-        '--fuse-conv-bn',
-        action='store_true',
-        help='Whether to fuse conv and bn, this will slightly increase'
-        'the inference speed')
-    parser.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed training)')
-    parser.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed testing)')
-    parser.add_argument(
-        '--format-only',
-        action='store_true',
-        help='Format the output results without perform evaluation. It is'
-        'useful when you want to format the result to a specific format and '
-        'submit it to the test server')
-    parser.add_argument(
-        '--eval',
-        type=str,
-        nargs='+',
-        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
-        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
-    parser.add_argument('--show', action='store_true', help='show results')
-    parser.add_argument(
-        '--show-dir', help='directory where results will be saved')
-    parser.add_argument(
-        '--gpu-collect',
-        action='store_true',
-        help='whether to use gpu to collect results.')
-    parser.add_argument(
-        '--tmpdir',
-        help='tmp directory used for collecting results from multiple '
-        'workers, available when gpu-collect is not specified')
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument(
-        '--deterministic',
-        action='store_true',
-        help='whether to set deterministic options for CUDNN backend.')
+        '--work-dir',
+        help='the directory to save the file containing evaluation metrics')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -94,19 +28,6 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function (deprecate), '
-        'change to --eval-options instead.')
-    parser.add_argument(
-        '--eval-options',
-        nargs='+',
-        action=DictAction,
-        help='custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -116,144 +37,38 @@ def parse_args():
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
-
-    if args.options and args.eval_options:
-        raise ValueError(
-            '--options and --eval-options cannot be both specified, '
-            '--options is deprecated in favor of --eval-options')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --eval-options')
-        args.eval_options = args.options
     return args
 
 
 def main():
     args = parse_args()
 
-    assert args.out or args.eval or args.format_only or args.show \
-        or args.show_dir, \
-        ('Please specify at least one operation (save/eval/format/show the '
-         'results / save the results) with the argument "--out", "--eval"'
-         ', "--format-only", "--show" or "--show-dir"')
-
-    if args.eval and args.format_only:
-        raise ValueError('--eval and --format_only cannot be both specified')
-
-    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
-        raise ValueError('The output file must be a pkl file.')
-
-    # load config
+    # register all modules in mmdet3d into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
     cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
-    cfg = compat_cfg(cfg)
-
-    # set multi-process settings
-    setup_multi_processes(cfg)
-
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-
-    cfg.model.pretrained = None
-
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed testing. Use the first GPU '
-                      'in `gpu_ids` now.')
-    else:
-        cfg.gpu_ids = [args.gpu_id]
-
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-
-    test_dataloader_default_args = dict(
-        samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)
-
-    # in case the test dataset is concatenated
-    if isinstance(cfg.data.test, dict):
-        cfg.data.test.test_mode = True
-        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
-            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
-            cfg.data.test.pipeline = replace_ImageToTensor(
-                cfg.data.test.pipeline)
-    elif isinstance(cfg.data.test, list):
-        for ds_cfg in cfg.data.test:
-            ds_cfg.test_mode = True
-        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
-            for ds_cfg in cfg.data.test:
-                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
-
-    test_loader_cfg = {
-        **test_dataloader_default_args,
-        **cfg.data.get('test_dataloader', {})
-    }
-
-    # set random seeds
-    if args.seed is not None:
-        set_random_seed(args.seed, deterministic=args.deterministic)
-
-    # build the dataloader
-    dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(dataset, **test_loader_cfg)
-
-    # build the model and load checkpoint
-    cfg.model.train_cfg = None
-    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    if args.fuse_conv_bn:
-        model = fuse_conv_bn(model)
-    # old versions did not save class info in checkpoints, this walkaround is
-    # for backward compatibility
-    if 'CLASSES' in checkpoint.get('meta', {}):
-        model.CLASSES = checkpoint['meta']['CLASSES']
-    else:
-        model.CLASSES = dataset.CLASSES
-    # palette for visualization in segmentation tasks
-    if 'PALETTE' in checkpoint.get('meta', {}):
-        model.PALETTE = checkpoint['meta']['PALETTE']
-    elif hasattr(dataset, 'PALETTE'):
-        # segmentation dataset has `PALETTE` attribute
-        model.PALETTE = dataset.PALETTE
-
-    if not distributed:
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
-    else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
-        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
-                                 args.gpu_collect)
-
-    rank, _ = get_dist_info()
-    if rank == 0:
-        if args.out:
-            print(f'\nwriting results to {args.out}')
-            mmcv.dump(outputs, args.out)
-        kwargs = {} if args.eval_options is None else args.eval_options
-        if args.format_only:
-            dataset.format_results(outputs, **kwargs)
-        if args.eval:
-            eval_kwargs = cfg.get('evaluation', {}).copy()
-            # hard-code way to remove EvalHook args
-            for key in [
-                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
-                    'rule'
-            ]:
-                eval_kwargs.pop(key, None)
-            eval_kwargs.update(dict(metric=args.eval, **kwargs))
-            print(dataset.evaluate(outputs, **eval_kwargs))
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+
+    cfg.load_from = args.checkpoint
+
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
+
+    # start testing
+    runner.test()
 
 
 if __name__ == '__main__':
...
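
Taken together, the refactored test entry point only prepares the config and hands control to mmengine's Runner. A hedged stand-alone sketch of the same flow follows; the config and checkpoint paths are placeholders rather than files from this commit, and the config is assumed to define the test dataloader, evaluator and test_cfg in mmengine style.

# Hedged sketch mirroring the new main(); the paths below are placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

from mmdet3d.utils import register_all_modules

register_all_modules(init_default_scope=False)  # the Runner initializes the scope

cfg = Config.fromfile('configs/example_config.py')    # placeholder config path
cfg.launcher = 'none'                                 # single-process testing
cfg.work_dir = './work_dirs/example_config'
cfg.load_from = 'checkpoints/example.pth'             # placeholder checkpoint path

runner = Runner.from_cfg(cfg)
runner.test()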
 # Copyright (c) OpenMMLab. All rights reserved.
 from __future__ import division
 import argparse
-import copy
 import os
-import time
-import warnings
 from os import path as osp
 
-import mmcv
-import torch
-import torch.distributed as dist
 from mmcv import Config, DictAction
-from mmcv.runner import get_dist_info, init_dist
+from mmengine import Runner
 
-from mmdet import __version__ as mmdet_version
-from mmdet3d import __version__ as mmdet3d_version
-from mmdet3d.apis import init_random_seed, train_model
-from mmdet3d.datasets import build_dataset
-from mmdet3d.models import build_model
-from mmdet3d.utils import collect_env, get_root_logger
-from mmdet.apis import set_random_seed
-from mmseg import __version__ as mmseg_version
-
-try:
-    # If mmdet version > 2.20.0, setup_multi_processes would be imported and
-    # used from mmdet instead of mmdet3d.
-    from mmdet.utils import setup_multi_processes
-except ImportError:
-    from mmdet3d.utils import setup_multi_processes
+from mmdet3d.utils import register_all_modules
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
     parser.add_argument('config', help='train config file path')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
-    parser.add_argument(
-        '--resume-from', help='the checkpoint file to resume from')
-    parser.add_argument(
-        '--auto-resume',
-        action='store_true',
-        help='resume from the latest checkpoint automatically')
-    parser.add_argument(
-        '--no-validate',
-        action='store_true',
-        help='whether not to evaluate the checkpoint during training')
-    group_gpus = parser.add_mutually_exclusive_group()
-    group_gpus.add_argument(
-        '--gpus',
-        type=int,
-        help='(Deprecated, please use --gpu-id) number of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='number of gpus to use '
-        '(only applicable to non-distributed training)')
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument(
-        '--diff-seed',
-        action='store_true',
-        help='Whether or not set different seeds for different ranks')
-    parser.add_argument(
-        '--deterministic',
-        action='store_true',
-        help='whether to set deterministic options for CUDNN backend.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into config file (deprecate), '
-        'change to --cfg-options instead.')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -94,39 +30,24 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
-    parser.add_argument(
-        '--autoscale-lr',
-        action='store_true',
-        help='automatically scale lr with the number of gpus')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
-    if args.options and args.cfg_options:
-        raise ValueError(
-            '--options and --cfg-options cannot be both specified, '
-            '--options is deprecated in favor of --cfg-options')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --cfg-options')
-        args.cfg_options = args.options
 
     return args
 
 
 def main():
     args = parse_args()
 
+    # register all modules in mmdet3d into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
     cfg = Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    # set multi-process settings
-    setup_multi_processes(cfg)
-
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
         # update configs according to CLI args if args.work_dir is not None
@@ -135,128 +56,12 @@ def main():
         # use config filename as default work_dir if cfg.work_dir is None
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
-    if args.resume_from is not None:
-        cfg.resume_from = args.resume_from
-
-    if args.auto_resume:
-        cfg.auto_resume = args.auto_resume
-        warnings.warn('`--auto-resume` is only supported when mmdet'
-                      'version >= 2.20.0 for 3D detection model or'
-                      'mmsegmentation verision >= 0.21.0 for 3D'
-                      'segmentation model')
-
-    if args.gpus is not None:
-        cfg.gpu_ids = range(1)
-        warnings.warn('`--gpus` is deprecated because we only support '
-                      'single GPU mode in non-distributed training. '
-                      'Use `gpus=1` now.')
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed training. Use the first GPU '
-                      'in `gpu_ids` now.')
-    if args.gpus is None and args.gpu_ids is None:
-        cfg.gpu_ids = [args.gpu_id]
-
-    if args.autoscale_lr:
-        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
-        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
-
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-        # re-set gpu_ids with distributed training mode
-        _, world_size = get_dist_info()
-        cfg.gpu_ids = range(world_size)
-
-    # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
-    # dump config
-    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
-    # init the logger before other steps
-    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
-    # specify logger name, if we still use 'mmdet', the output info will be
-    # filtered and won't be saved in the log_file
-    # TODO: ugly workaround to judge whether we are training det or seg model
-    if cfg.model.type in ['EncoderDecoder3D']:
-        logger_name = 'mmseg'
-    else:
-        logger_name = 'mmdet'
-    logger = get_root_logger(
-        log_file=log_file, log_level=cfg.log_level, name=logger_name)
-
-    # init the meta dict to record some important information such as
-    # environment info and seed, which will be logged
-    meta = dict()
-    # log env info
-    env_info_dict = collect_env()
-    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
-    dash_line = '-' * 60 + '\n'
-    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
-                dash_line)
-    meta['env_info'] = env_info
-    meta['config'] = cfg.pretty_text
-
-    # log some basic info
-    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
-
-    # set random seeds
-    seed = init_random_seed(args.seed)
-    seed = seed + dist.get_rank() if args.diff_seed else seed
-    logger.info(f'Set random seed to {seed}, '
-                f'deterministic: {args.deterministic}')
-    set_random_seed(seed, deterministic=args.deterministic)
-    cfg.seed = seed
-    meta['seed'] = seed
-    meta['exp_name'] = osp.basename(args.config)
-
-    model = build_model(
-        cfg.model,
-        train_cfg=cfg.get('train_cfg'),
-        test_cfg=cfg.get('test_cfg'))
-    model.init_weights()
-
-    logger.info(f'Model:\n{model}')
-    datasets = [build_dataset(cfg.data.train)]
-    if len(cfg.workflow) == 2:
-        val_dataset = copy.deepcopy(cfg.data.val)
-        # in case we use a dataset wrapper
-        if 'dataset' in cfg.data.train:
-            val_dataset.pipeline = cfg.data.train.dataset.pipeline
-        else:
-            val_dataset.pipeline = cfg.data.train.pipeline
-        # set test_mode=False here in deep copied config
-        # which do not affect AP/AR calculation later
-        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
-        val_dataset.test_mode = False
-        datasets.append(build_dataset(val_dataset))
-    if cfg.checkpoint_config is not None:
-        # save mmdet version, config file content and class names in
-        # checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmdet_version=mmdet_version,
-            mmseg_version=mmseg_version,
-            mmdet3d_version=mmdet3d_version,
-            config=cfg.pretty_text,
-            CLASSES=datasets[0].CLASSES,
-            PALETTE=datasets[0].PALETTE  # for segmentors
-            if hasattr(datasets[0], 'PALETTE') else None)
-    # add an attribute for visualization convenience
-    model.CLASSES = datasets[0].CLASSES
-    train_model(
-        model,
-        datasets,
-        cfg,
-        distributed=distributed,
-        validate=(not args.no_validate),
-        timestamp=timestamp,
-        meta=meta)
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
+
+    # start training
+    runner.train()
 
 
 if __name__ == '__main__':
...
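
The training entry point shrinks the same way: seeding, logging, checkpoint metadata and the distributed setup are no longer handled in the script, because Runner.from_cfg(cfg) builds them from the config. Below is a hedged sketch of the resulting flow, reusing the imports the new train.py keeps; unlike the test script above, the training flow shown here sets neither cfg.launcher nor cfg.load_from, and the config path is a placeholder.

# Hedged sketch mirroring the refactored training flow; the path is a placeholder
# and the config is assumed to define dataloaders, model and optim_wrapper.
import os.path as osp

from mmcv import Config
from mmengine import Runner

from mmdet3d.utils import register_all_modules

register_all_modules(init_default_scope=False)  # the Runner initializes the scope

config_path = 'configs/example_config.py'  # placeholder
cfg = Config.fromfile(config_path)
cfg.work_dir = osp.join('./work_dirs',
                        osp.splitext(osp.basename(config_path))[0])

runner = Runner.from_cfg(cfg)
runner.train()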