# global_vars.py
import time

import torch

from .WandbLog import TensorboardLog
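
# Module-level singletons; set once per process via set_global_variables().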

_GLOBAL_TIMERS = None
_GLOBAL_TENSORBOARD_WRITER = None


def set_global_variables(launch_time, tensorboard_path):
    _set_timers()
    _set_tensorboard_writer(launch_time, tensorboard_path)


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _set_tensorboard_writer(launch_time, tensorboard_path):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer')
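    # Only rank 0 creates a real writer; all other ranks leave it as None,
    # which callers of get_tensorboard_writer() must handle.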
    if torch.distributed.get_rank() == 0:
        _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time)


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def get_tensorboard_writer():
    """Return the tensorboard writer. It can legitimately be None
    (e.g. on non-zero ranks), so no initialization check is done."""
    return _GLOBAL_TENSORBOARD_WRITER
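
# Usage sketch (illustrative values; assumes torch.distributed has been
# initialized by the training entry point before this module is used):
#
#   set_global_variables(launch_time='20240101-0000', tensorboard_path='./tb')
#   timers = get_timers()
#   timers('forward').start()
#   ...  # forward pass
#   timers('forward').stop()
#   timers.log(['forward'])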


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None (i.e. not yet initialized)."""
    assert var is None, '{} is already initialized.'.format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        # assert not self.started_, 'timer has already been started'
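        # Synchronize outstanding CUDA work so queued kernels do not skew
        # the wall-clock measurement.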
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, stop it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_
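
# Note: elapsed(reset=True) returns the accumulated seconds and zeroes the
# counter; a timer that is running when queried is stopped and restarted
# transparently, so an in-flight measurement continues across the call.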


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently, torch.utils.tensorboard's add_scalars makes each timer
        # its own run, which pollutes the runs list, so we add each timer
        # as an individual scalar instead.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
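
# Illustrative training-loop integration (hypothetical loop variables;
# assumes the TensorboardLog wrapper exposes the add_scalar method that
# Timers.write() calls):
#
#   timers = get_timers()
#   writer = get_tensorboard_writer()
#   for iteration in range(num_iterations):
#       timers('iteration').start()
#       train_step()
#       timers('iteration').stop()
#       if writer is not None and iteration % log_interval == 0:
#           timers.write(['iteration'], writer, iteration)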