comm.py
# Copyright 2021 Toyota Research Institute.  All rights reserved.
import logging
from functools import wraps

import torch.distributed as dist

from detectron2.utils import comm as d2_comm

LOG = logging.getLogger(__name__)

_NESTED_BROADCAST_FROM_MASTER = False


def is_distributed():
    return d2_comm.get_world_size() > 1


def broadcast_from_master(fn):
    """If distributed, only the master executes the function and broadcast the results to other workers.

    Usage:
    @broadcast_from_master
    def foo(a, b): ...
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):  # pylint: disable=unused-argument
        global _NESTED_BROADCAST_FROM_MASTER

        if not is_distributed():
            return fn(*args, **kwargs)

        if _NESTED_BROADCAST_FROM_MASTER:
            # `fn` is being called from within another `broadcast_from_master`-decorated
            # function. Only the master can reach this point; run `fn` directly and let
            # the outermost wrapper perform the single broadcast.
            assert d2_comm.is_main_process()
            LOG.warning(f"_NESTED_BROADCAST_FROM_MASTER = True, {fn.__name__}")
            return fn(*args, **kwargs)

        if d2_comm.is_main_process():
            _NESTED_BROADCAST_FROM_MASTER = True
            # Only the master computes the result; wrap it in a one-element list so that
            # `broadcast_object_list` can send it as a picklable object.
            ret = [fn(*args, **kwargs), ]
            _NESTED_BROADCAST_FROM_MASTER = False
        else:
            ret = [None, ]
        if dist.is_initialized():
            # Overwrite the placeholder on non-master workers with the master's result.
            dist.broadcast_object_list(ret)
        ret = ret[0]

        assert ret is not None
        return ret

    return wrapper
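
# Usage sketch (illustrative; not part of the original module). `load_dataset_index`
# and the path below are hypothetical: only the master executes the body, and every
# worker receives the same return value.
#
#   @broadcast_from_master
#   def load_dataset_index(path):
#       with open(path) as f:        # runs on rank 0 only
#           return json.load(f)      # the result is broadcast to all other ranks
#
#   index = load_dataset_index("datasets/index.json")  # identical object on every rank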


def master_only(fn):
    """If distributed, only the master executes the function.

    Usage:
    @master_only
    def foo(a, b): ...
    """
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if d2_comm.is_main_process():
            ret = fn(*args, **kwargs)
        d2_comm.synchronize()
        if d2_comm.is_main_process():
            return ret

    return wrapped_fn
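
# Usage sketch (illustrative; not part of the original module). Writing a checkpoint is
# a typical master-only side effect; `save_checkpoint` and the call to `torch.save` are
# hypothetical (and assume `torch` is imported). All workers synchronize afterwards, and
# non-master workers receive `None`.
#
#   @master_only
#   def save_checkpoint(model, path):
#       torch.save(model.state_dict(), path)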


def gather_dict(dikt):
    """Gather python dictionaries from all workers to the rank=0 worker.

    Assumption: the keys of `dikt` are disjoint across all workers.

    If rank = 0, then return the aggregated dict.
    If rank > 0, then return `None`.
    """
    dict_lst = d2_comm.gather(dikt, dst=0)
    if d2_comm.is_main_process():
        gathered_dict = {}
        for dic in dict_lst:
            for k in dic.keys():
                assert k not in gathered_dict, f"Dictionary key overlaps: {k}"
            gathered_dict.update(dic)
        return gathered_dict
    else:
        return None
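
# Usage sketch (illustrative; not part of the original module). Each worker evaluates a
# disjoint shard of samples keyed by globally unique ids, so the per-rank dicts satisfy
# the disjoint-key assumption; `evaluate_sample` and `shard_ids` are hypothetical.
#
#   local_results = {sample_id: evaluate_sample(sample_id) for sample_id in shard_ids}
#   all_results = gather_dict(local_results)
#   if d2_comm.is_main_process():
#       ...  # `all_results` holds every worker's entries; it is None on the other ranks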


def reduce_sum(tensor):
    """
    Adapted from AdelaiDet:
        https://github.com/aim-uofa/AdelaiDet/blob/master/adet/utils/comm.py
    """
    if not is_distributed():
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor
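
# Usage sketch (illustrative; not part of the original module): summing a per-worker
# count across ranks, e.g. to normalize a loss by the global number of positive samples.
# Assumes `torch` is imported and the tensor lives on the device expected by the process
# group (typically CUDA for NCCL); `positives` is hypothetical.
#
#   num_pos = torch.tensor([float(len(positives))], device="cuda")
#   num_pos_global = reduce_sum(num_pos)
#   avg_num_pos = num_pos_global / d2_comm.get_world_size()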