Unverified Commit 53488dad authored by Wenhao Wu's avatar Wenhao Wu Committed by GitHub
Browse files

[Enhance] Accelerate training (#1168)

* accelerate training

* use multi_processes

* get args from config

* reorganize env setup

* import from mmdet

* import from mmdet
parent 22f6a4fb
...@@ -16,3 +16,8 @@ work_dir = None ...@@ -16,3 +16,8 @@ work_dir = None
load_from = None load_from = None
resume_from = None resume_from = None
workflow = [('train', 1)] workflow = [('train', 1)]
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'
...@@ -3,7 +3,9 @@ from mmcv.utils import Registry, build_from_cfg, print_log ...@@ -3,7 +3,9 @@ from mmcv.utils import Registry, build_from_cfg, print_log
from .collect_env import collect_env from .collect_env import collect_env
from .logger import get_root_logger from .logger import get_root_logger
from .setup_env import setup_multi_processes
__all__ = [ __all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'print_log' 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'print_log', 'setup_multi_processes'
] ]
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import os
import platform
import warnings
from torch import multiprocessing as mp
def setup_multi_processes(cfg):
"""Setup multi-processing environment variables."""
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
current_method = mp.get_start_method(allow_none=True)
if current_method is not None and current_method != mp_start_method:
warnings.warn(
f'Multi-processing start method `{mp_start_method}` is '
f'different from the previous setting `{current_method}`.'
f'It will be force set to `{mp_start_method}`. You can change '
f'this behavior by changing `mp_start_method` in your config.')
mp.set_start_method(mp_start_method, force=True)
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import multiprocessing as mp
import os
import platform
from mmcv import Config
from mmdet3d.utils import setup_multi_processes
def test_setup_multi_processes():
# temp save system setting
sys_start_mehod = mp.get_start_method(allow_none=True)
sys_cv_threads = cv2.getNumThreads()
# pop and temp save system env vars
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
# test config without setting env
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '1'
assert os.getenv('MKL_NUM_THREADS') == '1'
# when set to 0, the num threads will be 1
assert cv2.getNumThreads() == 1
if platform.system() != 'Windows':
assert mp.get_start_method() == 'fork'
# test num workers <= 1
os.environ.pop('OMP_NUM_THREADS')
os.environ.pop('MKL_NUM_THREADS')
config = dict(data=dict(workers_per_gpu=0))
cfg = Config(config)
setup_multi_processes(cfg)
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ
# test manually set env var
os.environ['OMP_NUM_THREADS'] = '4'
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '4'
# test manually set opencv threads and mp start method
config = dict(
data=dict(workers_per_gpu=2),
opencv_num_threads=4,
mp_start_method='spawn')
cfg = Config(config)
setup_multi_processes(cfg)
assert cv2.getNumThreads() == 4
assert mp.get_start_method() == 'spawn'
# revert setting to avoid affecting other programs
if sys_start_mehod:
mp.set_start_method(sys_start_mehod, force=True)
cv2.setNumThreads(sys_cv_threads)
if sys_omp_threads:
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
else:
os.environ.pop('OMP_NUM_THREADS')
if sys_mkl_threads:
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
else:
os.environ.pop('MKL_NUM_THREADS')
...@@ -16,6 +16,13 @@ from mmdet3d.models import build_model ...@@ -16,6 +16,13 @@ from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor from mmdet.datasets import replace_ImageToTensor
try:
# If mmdet version > 2.20.0, setup_multi_processes would be imported and
# used from mmdet instead of mmdet3d.
from mmdet.utils import setup_multi_processes
except ImportError:
from mmdet3d.utils import setup_multi_processes
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -117,6 +124,10 @@ def main(): ...@@ -117,6 +124,10 @@ def main():
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
......
...@@ -21,6 +21,13 @@ from mmdet3d.utils import collect_env, get_root_logger ...@@ -21,6 +21,13 @@ from mmdet3d.utils import collect_env, get_root_logger
from mmdet.apis import set_random_seed from mmdet.apis import set_random_seed
from mmseg import __version__ as mmseg_version from mmseg import __version__ as mmseg_version
try:
# If mmdet version > 2.20.0, setup_multi_processes would be imported and
# used from mmdet instead of mmdet3d.
from mmdet.utils import setup_multi_processes
except ImportError:
from mmdet3d.utils import setup_multi_processes
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Train a detector') parser = argparse.ArgumentParser(description='Train a detector')
...@@ -98,6 +105,9 @@ def main(): ...@@ -98,6 +105,9 @@ def main():
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment