# Copyright (c) OpenMMLab. All rights reserved.

3
import datetime
4
5
6
import os
import platform
import warnings
7
8

import cv2
VVsssssk's avatar
VVsssssk committed
9
from mmengine import DefaultScope
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import multiprocessing as mp


def _set_default_thread_env(env_var, workers_per_gpu):
    """Default ``env_var`` to one thread when several workers per GPU run.

    Leaves the environment untouched if the user already exported
    ``env_var`` or if at most one worker per GPU is configured.

    Args:
        env_var (str): Environment variable name, e.g. ``'OMP_NUM_THREADS'``.
        workers_per_gpu (int): Largest configured dataloader worker count.
    """
    if env_var in os.environ or workers_per_gpu <= 1:
        return
    num_threads = 1
    warnings.warn(
        f'Setting {env_var} environment variable for each process '
        f'to be {num_threads} in default, to avoid your system being '
        f'overloaded, please further tune the variable for optimal '
        f'performance in your application as needed.')
    os.environ[env_var] = str(num_threads)


def setup_multi_processes(cfg):
    """Setup multi-processing environment variables.

    Args:
        cfg: Config object supporting ``cfg.get(key, default)``. Keys read:
            ``mp_start_method`` (default ``'fork'``), ``opencv_num_threads``
            (default ``0``) and an optional ``data`` section providing
            ``workers_per_gpu`` and ``train_dataloader.workers_per_gpu``
            (each default ``1``).
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can change '
                f'this behavior by changing `mp_start_method` in your config.')
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # The largest worker count decides whether OMP/MKL threads are capped.
    # A config without a ``data`` section is treated as a single worker
    # instead of raising on attribute access.
    data_cfg = cfg.get('data', None) or {}
    workers_per_gpu = data_cfg.get('workers_per_gpu', 1)
    if 'train_dataloader' in data_cfg:
        workers_per_gpu = max(
            data_cfg['train_dataloader'].get('workers_per_gpu', 1),
            workers_per_gpu)

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    _set_default_thread_env('OMP_NUM_THREADS', workers_per_gpu)

    # setup MKL threads
    _set_default_thread_env('MKL_NUM_THREADS', workers_per_gpu)


def register_all_modules(init_default_scope: bool = True) -> None:
    """Register all modules in mmdet3d into the registries.

    Args:
        init_default_scope (bool): Whether initialize the mmdet3d default
            scope. When `init_default_scope=True`, the global default scope
            will be set to `mmdet3d`, and all registries will build modules
            from mmdet3d's registry node. To understand more about the
            registry, please refer to
            https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
            Defaults to True.
    """  # noqa
    # These subpackages are imported only so that their modules get
    # registered; the bound names are unused (hence the noqa markers).
    import mmdet3d.datasets  # noqa: F401,F403
    import mmdet3d.engine.scheduler  # noqa: F401,F403
    import mmdet3d.evaluation.metrics  # noqa: F401,F403
    import mmdet3d.structures  # noqa: F401,F403
    import mmdet3d.visualization  # noqa: F401,F403

    if not init_default_scope:
        return

    current_scope = DefaultScope.get_current_instance()
    # No scope yet, or an 'mmdet3d' instance was never created:
    # create it directly under its canonical name.
    if (current_scope is None
            or not DefaultScope.check_instance_created('mmdet3d')):
        DefaultScope.get_instance('mmdet3d', scope_name='mmdet3d')
        return

    if current_scope.scope_name != 'mmdet3d':
        warnings.warn('The current default scope '
                      f'"{current_scope.scope_name}" is not "mmdet3d", '
                      '`register_all_modules` will force the current '
                      'default scope to be "mmdet3d". If this is not '
                      'expected, please set `init_default_scope=False`.')
        # avoid name conflict: the 'mmdet3d' instance name is already taken,
        # so a timestamped name is used for the new scope instance
        new_instance_name = f'mmdet3d-{datetime.datetime.now()}'
        DefaultScope.get_instance(new_instance_name, scope_name='mmdet3d')