# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron global variables."""

from abc import ABC
from abc import abstractmethod
import math
import os
import sys
import time

import numpy as np
import torch

from megatron.tokenizer import build_tokenizer
from .arguments import parse_args

# Process-wide singletons. Each one is populated at most once by
# `set_global_variables` and read back through the `get_*` accessors below.
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS


def get_num_microbatches():
    """Return the current number of micro-batches."""
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def update_num_microbatches(consumed_samples):
    """Recompute the number of micro-batches for `consumed_samples`."""
    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                               'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)


def get_tokenizer():
    """Return tokenizer."""
    _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def get_adlr_autoresume():
    """ADLR autoresume object. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_ADLR_AUTORESUME


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults=None,
                         ignore_unknown_args=False):
    """Set args, num-microbatches calculator, tokenizer,
    tensorboard-writer, adlr-autoresume, and timers.

    Arguments:
        extra_args_provider: forwarded unchanged to `parse_args`.
        args_defaults: mapping of default argument values forwarded to
            `parse_args`; `None` is treated as an empty mapping.
        ignore_unknown_args: forwarded unchanged to `parse_args`.
    """
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
    _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults=None,
                ignore_unknown_args=False):
    """Parse entire arguments."""
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    # A fresh dict per call: a shared `{}` default would leak any mutation
    # made by the callee into every later call.
    if defaults is None:
        defaults = {}
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


def _build_num_microbatches_calculator(args):
    """Create the global micro-batch calculator from `args`.

    Uses a constant calculator when `args.rampup_batch_size` is None and a
    ramp-up calculator otherwise.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)

    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)


class NumMicroBatchesCalculator(ABC):
    """Interface for objects that track the number of micro-batches."""

    def __init__(self):
        self.num_micro_batches = None

    def get(self):
        """Return the most recently computed number of micro-batches."""
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        """Recompute the number of micro-batches for `consumed_samples`."""
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):
    """Calculator whose number of micro-batches never changes."""

    def __init__(self, global_batch_size, micro_batch_size,
                 data_parallel_size):
        super().__init__()
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1

    def update(self, consumed_samples):
        """No-op: the value computed in `__init__` is kept."""
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
    """Calculator that grows the global batch size in fixed increments."""

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.

        Over
            steps = (global-batch-size - start-batch-size) /
                    batch_size_increment
        increments, grow the batch size from start-batch-size to
        global-batch-size, spending rampup-samples / steps samples on each
        increment.

        Arguments:
            start_batch_size: global batch size to start with
            batch_size_increment: global batch size increments
            ramup_samples: number of samples to use ramp up global
               batch size from `start_batch_size` to `global_batch_size`
            global_batch_size: global batch size post rampup
            micro_batch_size: micro batch size
            data_parallel_size: data parallel size.
        """
        super().__init__()

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        # `num_increments` is 0 when the start and final batch sizes are
        # equal; there is then nothing to ramp and dividing would raise
        # ZeroDivisionError. 0.0 marks "no ramp" for `update`.
        if num_increments > 0:
            self.rampup_samples_per_increment = \
                self.ramup_samples / num_increments
        else:
            self.rampup_samples_per_increment = 0.0

        # Initialize number of microbatches.
        self.update(0)

    def update(self, consumed_samples):
        """Set `num_micro_batches` for the given number of consumed samples.

        The final global batch size is used once `consumed_samples` exceeds
        `ramup_samples`, or immediately when the ramp is degenerate (zero
        ramp-up samples or zero increments).
        """
        ramp_finished = consumed_samples > self.ramup_samples
        # A non-positive samples-per-increment means there is no ramp to
        # walk; this also avoids a division by zero below.
        if ramp_finished or self.rampup_samples_per_increment <= 0:
            current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert current_global_batch_size <= self.global_batch_size

        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
            self.micro_batch_times_data_parallel_size


def _build_tokenizer(args):
    """Initialize tokenizer."""
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    """Discard the current global tokenizer and build a new one."""
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)


def _set_tensorboard_writer(args):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')

    # Only rank 0 creates a writer, and only when a directory was given.
    if getattr(args, 'tensorboard_dir', None) and args.rank == 0:
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        # `userlib` is looked up in $SUBMIT_SCRIPTS (or the current
        # directory when the variable is unset).
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        except Exception:
            # `Exception` rather than `BaseException`, so KeyboardInterrupt
            # and SystemExit are not reported as a missing autoresume library.
            print('ADLR autoresume is not available, exiting ...')
            sys.exit()

        _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None."""
    assert var is None, '{} is already initialized.'.format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        # Wait for outstanding CUDA work before taking the timestamp.
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        # Wait for outstanding CUDA work before taking the timestamp.
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer called `name`, creating it on first use."""
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '_time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        # Print from every process when torch.distributed is not
        # initialized, otherwise only from rank 0.
        if not torch.distributed.is_initialized() or \
                torch.distributed.get_rank() == 0:
            print(string, flush=True)