# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron global variables."""

from abc import ABC
from abc import abstractmethod
import math
import os
import sys
import time

import numpy as np
import torch

from megatron.tokenizer import build_tokenizer

from .arguments import parse_args

# Process-wide singletons. Each starts as None and is assigned by the
# corresponding _parse_args / _build_* / _set_* helper, which is driven by
# set_global_variables(). The tensorboard writer and ADLR autoresume are
# optional and may legitimately remain None after initialization.
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Fetch the globally parsed arguments.

    Raises AssertionError (via _ensure_var_is_initialized) when the
    arguments have not been parsed yet.
    """
    args = _GLOBAL_ARGS
    _ensure_var_is_initialized(args, 'args')
    return args


def get_num_microbatches():
    """Return the current number of micro-batches per global batch.

    Raises AssertionError if the micro-batch calculator has not been built
    yet, consistent with the other global getters (previously an
    uninitialized call failed with an opaque AttributeError on None).
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def update_num_microbatches(consumed_samples):
    """Let the global calculator recompute micro-batches for `consumed_samples`.

    The constant calculator ignores the update; the ramp-up calculator
    recomputes its value from the number of samples consumed. Raises
    AssertionError if the calculator has not been built yet.
    """
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)


def get_tokenizer():
    """Return the global tokenizer.

    Raises AssertionError if the tokenizer has not been built yet.
    """
    _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Fetch the global TensorBoard writer.

    The writer is optional: it is only created on rank 0 when a
    tensorboard directory is configured, so None is a valid result and no
    initialization check is performed.
    """
    writer = _GLOBAL_TENSORBOARD_WRITER
    return writer


def get_adlr_autoresume():
    """Fetch the global ADLR autoresume object.

    Autoresume is only set up when requested, so None is a valid result
    and, unlike most getters here, no initialization check is made.
    """
    autoresume = _GLOBAL_ADLR_AUTORESUME
    return autoresume


def get_timers():
    """Fetch the global Timers collection.

    Raises AssertionError (via _ensure_var_is_initialized) when the timers
    have not been created yet.
    """
    timers = _GLOBAL_TIMERS
    _ensure_var_is_initialized(timers, 'timers')
    return timers


def set_global_variables(extra_args_provider=None, args_defaults=None,
                         ignore_unknown_args=False):
    """Set args, micro-batch calculator, tokenizer, tensorboard-writer,
    adlr-autoresume, and timers.

    Intended to be called once per process: a second call fails the
    "already initialized" assertion in _parse_args.

    Arguments:
        extra_args_provider: forwarded unchanged to parse_args (presumably a
            hook that registers additional command-line arguments).
        args_defaults: optional dict of defaults forwarded to parse_args;
            None means an empty dict. (A fresh dict is created per call to
            avoid sharing a mutable default argument across calls.)
        ignore_unknown_args: forwarded unchanged to parse_args.
    """
    if args_defaults is None:
        args_defaults = {}
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
    _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults=None,
                ignore_unknown_args=False):
    """Parse the full set of arguments and store them globally.

    Returns the parsed arguments. Raises AssertionError if the arguments
    were already parsed. `defaults` of None is treated as an empty dict
    (avoids a shared mutable default argument).
    """
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    if defaults is None:
        defaults = {}
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


def _build_num_microbatches_calculator(args):
    """Create the global micro-batch calculator from parsed arguments.

    Without `args.rampup_batch_size` the number of micro-batches is
    constant: global_batch_size // (micro_batch_size * data_parallel_size).
    Otherwise `args.rampup_batch_size` must hold three values
    (<start batch size> <batch size increment> <ramp-up samples>) that
    configure a RampupBatchsizeNumMicroBatches.

    Raises AssertionError if the calculator already exists, if (constant
    case) the global batch size is not divisible by
    micro_batch_size * data_parallel_size, or if `rampup_batch_size` does
    not have exactly three entries.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        micro_batch_times_data_parallel = args.micro_batch_size * \
                                          args.data_parallel_size
        assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(args.global_batch_size,
                                                    args.micro_batch_size,
                                                    args.data_parallel_size)
        num_micro_batches = args.global_batch_size // \
                            micro_batch_times_data_parallel
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_micro_batches), flush=True)
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
            num_micro_batches)

    # Batch size ramp-up.
    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        rampup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               rampup_samples), flush=True)
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, rampup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)


class NumMicroBatchesCalculator(ABC):
    """Base class for objects that track the number of micro-batches.

    Subclasses set `self.num_micro_batches` and implement `update`, which
    receives the number of samples consumed so far.
    """

    def __init__(self, name):
        # Human-readable description of the schedule.
        self.name = name
        self.num_micro_batches = None
        super(NumMicroBatchesCalculator, self).__init__()

    def get(self):
        """Return the current number of micro-batches."""
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        """Recompute the number of micro-batches given `consumed_samples`."""
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):
    """Calculator whose number of micro-batches never changes."""

    def __init__(self, num_micro_batches=1):
        super(ConstantNumMicroBatches, self).__init__(
            'constant: {}'.format(num_micro_batches))
        assert num_micro_batches >= 1
        self.num_micro_batches = num_micro_batches

    def update(self, consumed_samples):
        """No-op: the number of micro-batches is constant."""
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
    """Number of micro-batches for a step-wise ramped-up global batch size."""

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.

        The global batch size starts at `start_batch_size` and grows by
        `batch_size_increment` after every
            ramup_samples / ((global_batch_size - start_batch_size)
                             / batch_size_increment)
        consumed samples, until it reaches `global_batch_size`. If there is
        nothing to ramp (start == global) or `ramup_samples` is 0, the final
        global batch size is used from the start.

        Arguments:
            start_batch_size: global batch size to start with
            batch_size_increment: global batch size increments
            ramup_samples: number of samples to use ramp up global
               batch size from `start_batch_size` to `global_batch_size`
            global_batch_size: global batch size post rampup
            micro_batch_size: micro batch size
            data_parallel_size: data parallel size.

        Raises AssertionError on non-positive sizes, a start batch size
        larger than the global batch size, or an interval not divisible by
        `batch_size_increment`.
        """
        super(RampupBatchsizeNumMicroBatches, self).__init__(
            'batch size ramup: {}, {}, {}'.format(
                start_batch_size, batch_size_increment, ramup_samples))

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
                                                    self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        assert ramup_samples >= 0
        if num_increments > 0:
            self.rampup_samples_per_increment = ramup_samples / num_increments
        else:
            # start_batch_size == global_batch_size: there is nothing to ramp,
            # and dividing by num_increments would raise ZeroDivisionError.
            self.rampup_samples_per_increment = 0.0

        # Initialize number of microbatches.
        self.update(0)

    def update(self, consumed_samples):
        """Recompute the number of micro-batches after `consumed_samples`."""
        if self.rampup_samples_per_increment > 0:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            current_global_batch_size = min(
                self.start_batch_size + steps * self.batch_size_increment,
                self.global_batch_size)
        else:
            # Empty or zero-length ramp-up: avoid dividing by zero and use the
            # final global batch size immediately.
            current_global_batch_size = self.global_batch_size

        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
                                 self.micro_batch_times_data_parallel_size


def _build_tokenizer(args):
    """Build the global tokenizer from `args` and return it.

    Raises AssertionError if a tokenizer already exists; rebuild_tokenizer
    clears the global first when a replacement is needed.
    """
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    """Discard the current global tokenizer and build a fresh one.

    Returns the newly built tokenizer.
    """
    global _GLOBAL_TOKENIZER
    # Clear the global first so _build_tokenizer's "not initialized" check
    # passes.
    _GLOBAL_TOKENIZER = None
    fresh_tokenizer = _build_tokenizer(args)
    return fresh_tokenizer
def _set_tensorboard_writer(args):
    """Set tensorboard writer.

    A SummaryWriter is created only on rank 0 and only when
    `args.tensorboard_dir` is set; otherwise the global stays None. If
    TensorBoard cannot be imported, a warning is printed and the global
    stays None.
    """
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')

    if getattr(args, 'tensorboard_dir', None) and args.rank == 0:
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            # torch.utils.tensorboard raises a plain ImportError (not only
            # ModuleNotFoundError) when the tensorboard package is missing or
            # too old, so catch the base class.
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)
        else:
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir)


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume.

    When `args.adlr_autoresume` is set, appends $SUBMIT_SCRIPTS (default
    '.') to sys.path, imports `userlib.auto_resume.AutoResume` and stores
    that imported object globally (no instantiation happens here). If the
    import fails, the process exits.
    """
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if not args.adlr_autoresume:
        return

    if args.rank == 0:
        print('enabling autoresume ...', flush=True)
    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
    try:
        from userlib.auto_resume import AutoResume
    except ImportError:
        # Catch only import failures; the previous `except BaseException`
        # also swallowed KeyboardInterrupt and SystemExit.
        # NOTE(review): sys.exit() without an argument exits with status 0.
        print('ADLR autoresume is not available, exiting ...', flush=True)
        sys.exit()

    _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Create the global Timers collection (asserts it does not exist yet)."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    timers = Timers()
    _GLOBAL_TIMERS = timers


def _ensure_var_is_initialized(var, name):
    """Assert that `var` has been set, i.e. is not None.

    `name` is only used to build the assertion message.
    """
    is_set = var is not None
    assert is_set, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None, i.e. not initialized yet.

    Guards against initializing a global twice; `name` is only used in the
    assertion message.
    """
    assert var is None, '{} is already initialized.'.format(name)
class _Timer:
    """Wall-clock timer that accumulates elapsed seconds across start/stop.

    When CUDA is available, the device is synchronized before the clock is
    read, so previously queued GPU work is included in the measurement. On
    CPU-only setups the synchronization is skipped (torch.cuda.synchronize
    would fail there).
    """

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()
        # Decide once whether CUDA synchronization is possible.
        self._cuda_sync = torch.cuda.is_available()

    def _synchronize(self):
        """Wait for pending CUDA work; a no-op without CUDA."""
        if self._cuda_sync:
            torch.cuda.synchronize()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        self._synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer and add the interval to the accumulated total."""
        assert self.started_, 'timer is not started'
        self._synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Return the accumulated seconds, optionally resetting the total.

        A running timer is stopped to fold in the current interval and then
        restarted, so it is still running after the call.
        """
        was_running = self.started_
        # If the timing in progress, end it first.
        if was_running:
            self.stop()
        elapsed_ = self.elapsed_
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if was_running:
            self.start()
        return elapsed_


class Timers:
    """Group of named timers, created lazily on first access."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer called `name`, creating it on first use."""
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer.

        Each timer is logged as its own scalar `<name>_time` (seconds divided
        by `normalizer`) rather than via add_scalars, which would make each
        timer its own run and pollute the runs list.
        """
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '_time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Print the named timers in milliseconds, divided by `normalizer`.

        Prints only on rank 0 when torch.distributed is available and
        initialized; otherwise prints on every process.
        """
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        # Check is_available() first: builds without distributed support may
        # not expose is_initialized().
        distributed = (torch.distributed.is_available() and
                       torch.distributed.is_initialized())
        if not distributed or torch.distributed.get_rank() == 0:
            print(string, flush=True)