global_vars.py 6.37 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2
3
4
5
6

"""Megatron global variables."""

import os
import sys
7
8
from functools import reduce
import operator
Mohammad's avatar
Mohammad committed
9
import torch
Mohammad's avatar
Mohammad committed
10

11
from megatron import dist_signal_handler
12
from megatron.tokenizer import build_tokenizer
mohammad's avatar
mohammad committed
13
from .microbatches import build_num_microbatches_calculator
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
14
from .timers import Timers
Mohammad's avatar
Mohammad committed
15
16

# Process-wide singletons. Each starts as None and is populated by
# set_args or the matching _build_* / _set_* helper in this module; the
# public get_* accessors read them. The tensorboard writer and ADLR
# autoresume object may legitimately remain None after initialization.
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
_GLOBAL_MEMORY_BUFFER = None
Mohammad's avatar
Mohammad committed
24
25
26
27
28
29
30

def get_args():
    """Fetch the global arguments object stored by ``set_args``.

    Raises:
        AssertionError: if no arguments have been stored yet.
    """
    args = _GLOBAL_ARGS
    _ensure_var_is_initialized(args, 'args')
    return args


mohammad's avatar
mohammad committed
31
32
33
34
def get_num_microbatches():
    """Return the number of micro-batches reported by the global calculator.

    Delegates to the calculator's ``get()``.

    Raises:
        AssertionError: if the calculator has not been built yet (instead of
            an opaque AttributeError on None).
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


35
36
37
38
39
40
41
def get_current_global_batch_size():
    """Return the current global batch size from the global calculator.

    Delegates to the calculator's ``get_current_global_batch_size()``.

    Raises:
        AssertionError: if the calculator has not been built yet (instead of
            an opaque AttributeError on None).
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    """Forward an update to the global micro-batch calculator.

    Args:
        consumed_samples: passed to the calculator's ``update``; by its name,
            the number of samples consumed so far.
        consistency_check: passed unchanged to the calculator's ``update``.

    Raises:
        AssertionError: if the calculator has not been built yet (instead of
            an opaque AttributeError on None).
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
                                               consistency_check)
mohammad's avatar
mohammad committed
42
43


Mohammad's avatar
Mohammad committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def get_tokenizer():
    """Fetch the global tokenizer built by ``_build_tokenizer``.

    Raises:
        AssertionError: if no tokenizer has been built yet.
    """
    tokenizer = _GLOBAL_TOKENIZER
    _ensure_var_is_initialized(tokenizer, 'tokenizer')
    return tokenizer


def get_tensorboard_writer():
    """Fetch the global tensorboard writer, which may be None.

    No initialization check is made on purpose: the writer is only created
    on one rank and only when a tensorboard directory is configured, so
    None is a normal value here.
    """
    writer = _GLOBAL_TENSORBOARD_WRITER
    return writer


def get_adlr_autoresume():
    """Fetch the ADLR autoresume object, which may be None.

    No initialization check is made on purpose: it is only set when
    autoresume was requested, so None is a normal value here.
    """
    autoresume = _GLOBAL_ADLR_AUTORESUME
    return autoresume


def get_timers():
    """Fetch the global ``Timers`` instance created by ``_set_timers``.

    Raises:
        AssertionError: if the timers have not been created yet.
    """
    timers = _GLOBAL_TIMERS
    _ensure_var_is_initialized(timers, 'timers')
    return timers

67

68
69
70
71
def get_signal_handler():
    """Fetch the distributed signal handler installed by ``_set_signal_handler``.

    Raises:
        AssertionError: if no handler has been installed.
    """
    handler = _GLOBAL_SIGNAL_HANDLER
    _ensure_var_is_initialized(handler, 'signal handler')
    return handler

72
73
74
75
76
77

def get_global_memory_buffer():
    """Fetch the process-wide ``GlobalMemoryBuffer``.

    Raises:
        AssertionError: if the buffer has not been created yet.
    """
    memory_buffer = _GLOBAL_MEMORY_BUFFER
    _ensure_var_is_initialized(memory_buffer, 'global memory buffer')
    return memory_buffer


78
79
80
81
def _set_signal_handler():
    """Create a DistributedSignalHandler and store what its ``__enter__`` returns.

    ``__enter__`` is invoked directly rather than through a ``with`` block;
    this function never calls the matching ``__exit__``.

    Raises:
        AssertionError: if a handler is already stored.
    """
    global _GLOBAL_SIGNAL_HANDLER
    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    handler = dist_signal_handler.DistributedSignalHandler()
    _GLOBAL_SIGNAL_HANDLER = handler.__enter__()
Mohammad's avatar
Mohammad committed
82

83

84

85
def set_global_variables(args):
    """Initialize the process-wide globals from the parsed ``args``.

    Sets, in order: the args themselves, the num-microbatches calculator,
    the tokenizer (only when ``args.vocab_file`` is truthy), the tensorboard
    writer, the ADLR autoresume object, the timers, the global memory
    buffer, and -- only when ``args.exit_signal_handler`` is truthy -- the
    distributed signal handler.

    Raises:
        AssertionError: if ``args`` is None or any of these globals has
            already been initialized.
    """
    assert args is not None

    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    set_args(args)

    _build_num_microbatches_calculator(args)
    if args.vocab_file:
        # The tokenizer is stored globally; the returned value is not needed.
        _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers(args)
    _set_global_memory_buffer()

    if args.exit_signal_handler:
        _set_signal_handler()
103
104
    

105
106
107
def set_args(args):
    """Store ``args`` as the global arguments object.

    No "already initialized" check is made here, so calling it again
    replaces whatever was stored before.
    """
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args
Mohammad's avatar
Mohammad committed
108
109


mohammad's avatar
mohammad committed
110
111
112
113
114
115
def _build_num_microbatches_calculator(args):
    """Build the global num-microbatches calculator from ``args``.

    Raises:
        AssertionError: if the calculator has already been built.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        args)
mohammad's avatar
mohammad committed
118
119


Mohammad's avatar
Mohammad committed
120
def _build_tokenizer(args):
    """Initialize tokenizer.

    Builds a tokenizer from ``args``, stores it as the global tokenizer and
    returns it.

    Raises:
        AssertionError: if a tokenizer has already been built (use
            ``rebuild_tokenizer`` to replace it).
    """
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    """Discard any existing global tokenizer and build a new one from ``args``.

    Returns:
        The newly built tokenizer.
    """
    global _GLOBAL_TOKENIZER
    # Reset first so the "not initialized" check in _build_tokenizer passes.
    _GLOBAL_TOKENIZER = None
    fresh_tokenizer = _build_tokenizer(args)
    return fresh_tokenizer
Mohammad's avatar
Mohammad committed
132
133


Mohammad's avatar
Mohammad committed
134
def _set_tensorboard_writer(args):
    """Set tensorboard writer.

    A ``SummaryWriter`` is created only on the last rank
    (``args.rank == args.world_size - 1``) and only when ``args`` has a
    truthy ``tensorboard_dir``. On all other ranks, or when TensorBoard
    cannot be imported, the global stays None and a warning is printed in
    the latter case.

    Raises:
        AssertionError: if the writer has already been set.
    """
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')

    if getattr(args, 'tensorboard_dir', None) and \
       args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            # Catch ImportError, not just ModuleNotFoundError: some PyTorch
            # versions re-raise a missing/too-old tensorboard package as a
            # plain ImportError, which the narrower class would not catch.
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)
            return
        print('> setting tensorboard ...')
        _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
            log_dir=args.tensorboard_dir,
            max_queue=args.tensorboard_queue_size)


Mohammad's avatar
Mohammad committed
154
def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume.

    When ``args.adlr_autoresume`` is truthy, appends ``$SUBMIT_SCRIPTS``
    (default ``'.'``) to ``sys.path`` and imports ``AutoResume`` from
    ``userlib.auto_resume``. The imported class itself (not an instance) is
    stored globally. Otherwise the global stays None.

    Raises:
        AssertionError: if autoresume has already been set.
        SystemExit: with a non-zero status if ``userlib.auto_resume`` cannot
            be imported.
    """
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        except Exception:
            # Exception (not BaseException) so Ctrl-C / SystemExit raised
            # during the import propagate instead of being swallowed.
            print('ADLR autoresume is not available, exiting ...', flush=True)
            # Non-zero status: the requested feature is unavailable, and a
            # zero status would make the job look like it succeeded.
            sys.exit(1)

        _GLOBAL_ADLR_AUTORESUME = AutoResume


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
172
def _set_timers(args):
    """Initialize timers.

    Creates the global ``Timers`` from ``args.timing_log_level`` and
    ``args.timing_log_option``.

    Raises:
        AssertionError: if the timers have already been set.
    """
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)

Mohammad's avatar
Mohammad committed
178

179
180
181
182
183
184
def _set_global_memory_buffer():
    """Create the process-wide ``GlobalMemoryBuffer``.

    Raises:
        AssertionError: if the buffer already exists.
    """
    global _GLOBAL_MEMORY_BUFFER
    _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
    new_buffer = GlobalMemoryBuffer()
    _GLOBAL_MEMORY_BUFFER = new_buffer

Mohammad's avatar
Mohammad committed
185
186
187
188
189
190
191
192
193

def _ensure_var_is_initialized(var, name):
    """Assert that a global has been set, i.e. is not None.

    Args:
        var: current value of the global being checked.
        name: human-readable name used in the failure message.

    Raises:
        AssertionError: if ``var`` is None (assertions are skipped under -O).
    """
    is_set = var is not None
    assert is_set, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None, i.e. not yet initialized.

    Args:
        var: current value of the global being checked.
        name: human-readable name used in the failure message.

    Raises:
        AssertionError: if ``var`` is not None (assertions are skipped
            under -O).
    """
    assert var is None, '{} is already initialized.'.format(name)
Mohammad's avatar
Mohammad committed
194
195


196
197

class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.

    One flat buffer is cached per ``(name, dtype)`` (plus the device when
    one is passed explicitly) and re-allocated only when a request needs
    more elements than the cached buffer holds. Every tensor returned for a
    key is a view starting at element 0 of that key's buffer, so callers
    should ensure that buffers of the same name are not used concurrently.
    """

    def __init__(self):
        # Maps (name, dtype) -- or (name, dtype, device) -- to a 1-D tensor.
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name, device=None):
        """Return a tensor of ``tensor_shape`` backed by a cached buffer.

        The contents are unspecified: the storage comes from
        ``torch.empty`` and may hold data from earlier requests.

        Args:
            tensor_shape: iterable of ints; may be empty for a 0-d tensor.
            dtype: torch dtype of the buffer.
            name: key distinguishing independent buffers.
            device: optional device to allocate on. When omitted, the
                buffer is allocated on ``torch.cuda.current_device()`` and
                cached under ``(name, dtype)``, exactly as before.

        Returns:
            A view of the first ``prod(tensor_shape)`` buffer elements,
            reshaped to ``tensor_shape``.
        """
        # Materialize once: an iterator would otherwise be exhausted by
        # reduce() before view(), and view(shape) accepts an empty tuple for
        # a 0-d result whereas view(*()) passes no size argument at all.
        shape = tuple(tensor_shape)
        required_len = reduce(operator.mul, shape, 1)

        if device is None:
            key = (name, dtype)
            device = torch.cuda.current_device()
        else:
            device = torch.device(device)
            key = (name, dtype, device)

        buf = self.buffer.get(key)
        if buf is None or buf.numel() < required_len:
            buf = torch.empty(required_len,
                              dtype=dtype,
                              device=device,
                              requires_grad=False)
            self.buffer[key] = buf

        return buf[0:required_len].view(shape)