global_vars.py 5.36 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2
3
4
5
6

"""Megatron global variables."""

import os
import sys
Mohammad's avatar
Mohammad committed
7
import torch
Mohammad's avatar
Mohammad committed
8

9
from megatron import dist_signal_handler
10
from megatron.tokenizer import build_tokenizer
mohammad's avatar
mohammad committed
11
from .microbatches import build_num_microbatches_calculator
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
12
from .timers import Timers
Mohammad's avatar
Mohammad committed
13
14

# Arguments object for this process; assigned by set_args().
_GLOBAL_ARGS = None
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
15
# Retro arguments; assigned by set_retro_args() and read back without an
# initialization check, so this may legitimately stay None.
_GLOBAL_RETRO_ARGS = None
mohammad's avatar
mohammad committed
16
# Calculator object created by _build_num_microbatches_calculator().
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
Mohammad's avatar
Mohammad committed
17
18
19
20
# Tokenizer created by _build_tokenizer() / rebuild_tokenizer().
_GLOBAL_TOKENIZER = None
# SummaryWriter created by _set_tensorboard_writer(); may stay None.
_GLOBAL_TENSORBOARD_WRITER = None
# AutoResume object imported by _set_adlr_autoresume(); may stay None.
_GLOBAL_ADLR_AUTORESUME = None
# Timers instance created by _set_timers().
_GLOBAL_TIMERS = None
21
# Entered distributed signal handler; assigned by _set_signal_handler().
_GLOBAL_SIGNAL_HANDLER = None
Mohammad's avatar
Mohammad committed
22
23
24
25
26
27
28

def get_args():
    """Return the global arguments object.

    Asserts that the arguments have already been set.
    """
    args = _GLOBAL_ARGS
    _ensure_var_is_initialized(args, 'args')
    return args


Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
29
30
31
32
33
def get_retro_args():
    """Return the global retro arguments (None if they were never set)."""
    retro_args = _GLOBAL_RETRO_ARGS
    return retro_args


mohammad's avatar
mohammad committed
34
35
36
37
def get_num_microbatches():
    """Return the value reported by the global num-microbatches calculator.

    Asserts that the calculator has been built, so a premature call fails
    with a clear "not initialized" message instead of an AttributeError
    raised on None.
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


38
39
40
41
42
43
44
def get_current_global_batch_size():
    """Return the current global batch size from the global calculator.

    Asserts that the calculator has been built, so a premature call fails
    with a clear "not initialized" message instead of an AttributeError
    raised on None.
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    """Forward an update to the global num-microbatches calculator.

    Both arguments are passed through unchanged to the calculator's
    ``update`` method. Asserts that the calculator has been built, so a
    premature call fails with a clear "not initialized" message instead of
    an AttributeError raised on None.
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
                                               consistency_check)
mohammad's avatar
mohammad committed
45
46


Mohammad's avatar
Mohammad committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def get_tokenizer():
    """Return the global tokenizer.

    Asserts that the tokenizer has already been built.
    """
    tokenizer = _GLOBAL_TOKENIZER
    _ensure_var_is_initialized(tokenizer, 'tokenizer')
    return tokenizer


def get_tensorboard_writer():
    """Return the global tensorboard writer.

    The writer is allowed to be None, so no initialization check is made.
    """
    writer = _GLOBAL_TENSORBOARD_WRITER
    return writer


def get_adlr_autoresume():
    """Return the global ADLR autoresume object.

    The object is allowed to be None, so no initialization check is made.
    """
    autoresume = _GLOBAL_ADLR_AUTORESUME
    return autoresume


def get_timers():
    """Return the global timers.

    Asserts that the timers have already been created.
    """
    timers = _GLOBAL_TIMERS
    _ensure_var_is_initialized(timers, 'timers')
    return timers

70

71
72
73
74
def get_signal_handler():
    """Return the global signal handler.

    Asserts that the signal handler has already been set.
    """
    handler = _GLOBAL_SIGNAL_HANDLER
    _ensure_var_is_initialized(handler, 'signal handler')
    return handler

75

76
77
78
79
def _set_signal_handler():
    """Create the distributed signal handler and store it globally.

    The handler's ``__enter__`` is called directly (there is no ``with``
    block here), and whatever it returns is what gets stored.
    """
    global _GLOBAL_SIGNAL_HANDLER
    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    handler = dist_signal_handler.DistributedSignalHandler()
    _GLOBAL_SIGNAL_HANDLER = handler.__enter__()
Mohammad's avatar
Mohammad committed
80

81

82

83
def set_global_variables(args):
    """Populate this module's global state from ``args``.

    Stores the arguments, builds the num-microbatches calculator, builds the
    tokenizer when ``args.vocab_file`` or ``args.tokenizer_model`` is set,
    then sets up the tensorboard writer, ADLR autoresume and timers. The
    signal handler is installed only when ``args.exit_signal_handler`` is
    truthy.
    """
    assert args is not None

    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    set_args(args)

    _build_num_microbatches_calculator(args)

    wants_tokenizer = args.vocab_file or args.tokenizer_model
    if wants_tokenizer:
        _build_tokenizer(args)

    # Order matters only in that it mirrors the original call sequence.
    for setup in (_set_tensorboard_writer, _set_adlr_autoresume, _set_timers):
        setup(args)

    if args.exit_signal_handler:
        _set_signal_handler()
100
101
    

102
103
104
def set_args(args):
    """Store ``args`` as the global arguments, replacing any previous value.

    No initialization check is made here.
    """
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args
Mohammad's avatar
Mohammad committed
105
106


Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
107
108
109
110
111
def set_retro_args(retro_args):
    """Store ``retro_args`` as the global retro arguments.

    Any previous value is replaced; no initialization check is made here.
    """
    global _GLOBAL_RETRO_ARGS
    _GLOBAL_RETRO_ARGS = retro_args


mohammad's avatar
mohammad committed
112
113
114
115
116
117
def _build_num_microbatches_calculator(args):
    """Build the global num-microbatches calculator from ``args``.

    Asserts that no calculator has been built before.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator')

    calculator = build_num_microbatches_calculator(args)
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = calculator
mohammad's avatar
mohammad committed
120
121


Mohammad's avatar
Mohammad committed
122
def _build_tokenizer(args):
    """Build the tokenizer from ``args``, store it globally and return it.

    Asserts that no tokenizer has been built before.
    """
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    tokenizer = build_tokenizer(args)
    _GLOBAL_TOKENIZER = tokenizer
    return tokenizer


def rebuild_tokenizer(args):
    """Discard the global tokenizer, build a new one and return it."""
    global _GLOBAL_TOKENIZER
    # Clear first: _build_tokenizer asserts that no tokenizer is set.
    _GLOBAL_TOKENIZER = None
    rebuilt = _build_tokenizer(args)
    return rebuilt
Mohammad's avatar
Mohammad committed
134
135


Mohammad's avatar
Mohammad committed
136
def _set_tensorboard_writer(args):
    """Create the global tensorboard writer when one is requested.

    A writer is created only if ``args`` has a truthy ``tensorboard_dir``
    and ``args.rank`` equals ``args.world_size - 1``. If the tensorboard
    module cannot be found, a warning is printed and the global stays None.
    """
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')

    # Nothing to do when no tensorboard directory was configured.
    if not getattr(args, 'tensorboard_dir', None):
        return

    # Only the last rank writes tensorboard logs.
    last_rank = args.world_size - 1
    if args.rank != last_rank:
        return

    try:
        from torch.utils.tensorboard import SummaryWriter
        print('> setting tensorboard ...')
        writer = SummaryWriter(log_dir=args.tensorboard_dir,
                               max_queue=args.tensorboard_queue_size)
        _GLOBAL_TENSORBOARD_WRITER = writer
    except ModuleNotFoundError:
        print('WARNING: TensorBoard writing requested but is not '
              'available (are you using PyTorch 1.1.0 or later?), '
              'no TensorBoard logs will be written.', flush=True)


Mohammad's avatar
Mohammad committed
156
def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume.

    When ``args.adlr_autoresume`` is truthy, the directory named by the
    ``SUBMIT_SCRIPTS`` environment variable (``'.'`` if unset) is appended
    to ``sys.path``, ``AutoResume`` is imported from ``userlib.auto_resume``
    and stored globally. Otherwise the global is left as None.

    Raises:
        SystemExit: with status 1 if ``userlib.auto_resume`` cannot be
            imported.
    """
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if not args.adlr_autoresume:
        return

    if args.rank == 0:
        print('enabling autoresume ...', flush=True)
    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
    try:
        from userlib.auto_resume import AutoResume
    except ImportError:
        # Catch only a failed import: BaseException would also swallow
        # KeyboardInterrupt/SystemExit and report them as "not available".
        print('ADLR autoresume is not available, exiting ...', flush=True)
        # Exit non-zero so the launcher sees this as a failure, not success.
        sys.exit(1)

    _GLOBAL_ADLR_AUTORESUME = AutoResume


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
174
def _set_timers(args):
    """Create the global timers from the timing options in ``args``.

    Asserts that the timers have not been created before.
    """
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    log_level = args.timing_log_level
    log_option = args.timing_log_option
    _GLOBAL_TIMERS = Timers(log_level, log_option)

Mohammad's avatar
Mohammad committed
180
181
182
183
184
185
186
187
188

def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is None, '{} is already initialized.'.format(name)
Mohammad's avatar
Mohammad committed
189
190


191