Commit 94e2ca57 authored by Mohammad

arguments.py refactored

parent a9e19f8e
@@ -13,53 +13,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""argparser configuration"""
+"""Megatron arguments."""
 import argparse
 import os
-import torch
-_GLOBAL_ARGS = None
-def _print_args():
-    """Print arguments."""
-    args = get_args()
-    writer = get_tensorboard_writer()
-    print_rank_0('arguments:')
-    str_list = []
-    for arg in vars(args):
-        dots = '.' * (29 - len(arg))
-        str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
-        if writer:
-            writer.add_text(arg, str(getattr(args, arg)))
-    for arg in sorted(str_list, key= lambda x: x.lower()):
-        print_rank_0(arg)
-def parse_args(extra_args_provider=None):
-    global _GLOBAL_ARGS
-    assert _GLOBAL_ARGS is None, 'args already initializeed'
-    _GLOBAL_ARGS = get_args_(extra_args_provider=extra_args_provider)
-    return _GLOBAL_ARGS
-def get_args(extra_args_provider=None):
-    global _GLOBAL_ARGS
-    if _GLOBAL_ARGS is None:
-        return parse_args(extra_args_provider=extra_args_provider)
-    else:
-        return _GLOBAL_ARGS
-def add_network_size_args(parser):
+def parse_args(extra_args_provider=None, defaults={}):
+    """Parse all arguments."""
+    parser = argparse.ArgumentParser(description='Megatron-LM Arguments')
+    # Standard arguments.
+    parser = _add_network_size_args(parser)
+    parser = _add_regularization_args(parser)
+    parser = _add_training_args(parser)
+    parser = _add_initialization_args(parser)
+    parser = _add_learning_rate_args(parser)
+    parser = _add_checkpointing_args(parser)
+    parser = _add_mixed_precision_args(parser)
+    parser = _add_distributed_args(parser)
+    parser = _add_validation_args(parser)
+    parser = _add_data_args(parser)
+    parser = _add_autoresume_args(parser)
+    # Custom arguments.
+    if extra_args_provider is not None:
+        parser = extra_args_provider(parser)
+    # Parse.
+    args = parser.parse_args()
+    # Set input defaults.
+    for key in defaults:
+        setattr(args, key, defaults[key])
+    # Distributed args.
+    args.rank = int(os.getenv('RANK', '0'))
+    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+    args.model_parallel_size = min(args.model_parallel_size, args.world_size)
+    if args.rank == 0:
+        print('using world size: {} and model-parallel size: {} '.format(
+            args.world_size, args.model_parallel_size))
+    # Fp16 loss scaling.
+    args.dynamic_loss_scale = False
+    if args.loss_scale is None:
+        args.dynamic_loss_scale = True
+    # Checks.
+    assert args.hidden_size % args.num_attention_heads == 0
+    assert args.max_position_embeddings >= args.seq_length
+    assert args.min_lr <= args.lr
+    if args.save is not None:
+        assert args.save_interval is not None
+    _print_args(args)
+    return args
+def _print_args(args):
+    """Print arguments."""
+    if args.rank == 0:
+        print('-------------------- arguments --------------------', flush=True)
+        str_list = []
+        for arg in vars(args):
+            dots = '.' * (32 - len(arg))
+            str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
+        for arg in sorted(str_list, key=lambda x: x.lower()):
+            print(arg, flush=True)
+        print('---------------- end of arguments ----------------', flush=True)
+def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
     group.add_argument('--num-layers', type=int, required=True,
                        help='Number of transformer layers.')
     group.add_argument('--hidden-size', type=int, required=True,
@@ -72,11 +100,13 @@ def add_network_size_args(parser):
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
+    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
+                       help='Layer norm epsilon.')
     return parser
-def add_regularization_args(parser):
+def _add_regularization_args(parser):
     group = parser.add_argument_group(title='regularization')
     group.add_argument('--attention-dropout', type=float, default=0.1,
@@ -89,9 +119,9 @@ def add_regularization_args(parser):
                        help='Gradient clipping based on global L2 norm.')
     return parser
-def add_training_args(parser):
+def _add_training_args(parser):
     group = parser.add_argument_group(title='training')
     group.add_argument('--batch-size', type=int, required=True,
@@ -103,7 +133,7 @@ def add_training_args(parser):
                        'with larger models, sequences, and batch sizes.')
     group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='chunk size (number of layers) for checkpointing.')
-    group.add_argument('--train-iters', type=int, required=True,
+    group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs.')
     group.add_argument('--log-interval', type=int, default=100,
@@ -117,7 +147,7 @@ def add_training_args(parser):
     return parser
-def add_initialization_args(parser):
+def _add_initialization_args(parser):
     group = parser.add_argument_group(title='initialization')
     group.add_argument('--seed', type=int, default=1234,
@@ -126,11 +156,11 @@ def add_initialization_args(parser):
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
     return parser
-def add_learning_rate_args(parser):
+def _add_learning_rate_args(parser):
     group = parser.add_argument_group(title='learning rate')
     group.add_argument('--lr', type=float, required=True,
@@ -164,7 +194,7 @@ def add_learning_rate_args(parser):
     return parser
-def add_checkpointing_args(parser):
+def _add_checkpointing_args(parser):
     group = parser.add_argument_group(title='checkpointing')
     group.add_argument('--save', type=str, default=None,
@@ -189,7 +219,7 @@ def add_checkpointing_args(parser):
     return parser
-def add_mixed_precision_args(parser):
+def _add_mixed_precision_args(parser):
     group = parser.add_argument_group(title='mixed precision')
     group.add_argument('--fp16', action='store_true',
@@ -214,7 +244,7 @@ def add_mixed_precision_args(parser):
     return parser
-def add_distributed_args(parser):
+def _add_distributed_args(parser):
     group = parser.add_argument_group(title='mixed precision')
     group.add_argument('--model-parallel-size', type=int, default=1,
@@ -223,7 +253,7 @@ def add_distributed_args(parser):
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
     group.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
     group.add_argument('--local_rank', type=int, default=None,
@@ -232,7 +262,7 @@ def add_distributed_args(parser):
     return parser
-def add_validation_args(parser):
+def _add_validation_args(parser):
     group = parser.add_argument_group(title='validation')
     group.add_argument('--eval-iters', type=int, default=100,
@@ -245,12 +275,12 @@ def add_validation_args(parser):
     return parser
-def add_data_args(parser):
+def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')
-    group.add_argument('--data-path', type=str, required=True,
+    group.add_argument('--data-path', type=str, default=None,
                        help='Path to combined dataset to split.')
-    group.add_argument('--split', type=str, required=True,
+    group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
                        '`90,5,5` will use 90% of data for training, 5% for '
@@ -267,59 +297,31 @@ def add_data_args(parser):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--tokenizer-type', type=str,
+                       default=None,
+                       choices=['BertWordPieceLowerCase',
+                                'GPT2BPETokenizer'],
+                       help='What type of tokenizer to use.')
+    parser.add_argument('--data-impl', type=str, default='infer',
+                        choices=['lazy', 'cached', 'mmap', 'infer'],
+                        help='Implementation of indexed datasets.')
     return parser
-########################
-def add_model_config_args(parser):
-    """Model arguments"""
-    group = parser.add_argument_group('model', 'model configuration')
-    group.add_argument('--pretrained-bert', action='store_true',
-                       help='use a pretrained bert-large-uncased model instead'
-                       'of initializing from scratch. See '
-                       '--tokenizer-model-type to specify which pretrained '
-                       'BERT model to use')
-    group.add_argument('--intermediate-size', type=int, default=None,
-                       help='transformer embedding dimension for FFN'
-                       'set to 4*`--hidden-size` if it is None')
-    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
-                       help='layer norm epsilon')
-    group.add_argument('--deep-init', action='store_true',
-                       help='initialize bert model similar to gpt2 model.'
-                       'scales initialization of projection layers by a '
-                       'factor of 1/sqrt(2N). Necessary to train bert '
-                       'models larger than BERT-Large.')
-    group.add_argument('--vocab-size', type=int, default=None,
-                       help='vocabulary size to use for non-character-level '
-                       'tokenization. This value will only be used when '
-                       'creating a tokenizer')
-    return parser
-def add_fp16_config_args(parser):
-    """Mixed precision arguments."""
-    group = parser.add_argument_group('fp16', 'fp16 configurations')
-    group.add_argument('--fp32-embedding', action='store_true',
-                       help='embedding in fp32')
-    group.add_argument('--fp32-layernorm', action='store_true',
-                       help='layer norm in fp32')
-    group.add_argument('--fp32-tokentypes', action='store_true',
-                       help='embedding token types in fp32')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='all-reduce in fp32')
-    return parser
+def _add_autoresume_args(parser):
+    group = parser.add_argument_group(title='autoresume')
+    group.add_argument('--adlr-autoresume', action='store_true',
+                       help='Enable autoresume on adlr cluster.')
+    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
+                       help='Intervals over which check for autoresume'
+                       'termination signal')
+    return parser
+########################################################################
 def add_training_args_(parser):
@@ -336,15 +338,6 @@ def add_training_args_(parser):
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens')
-    # Learning rate.
-    # autoresume
-    group.add_argument('--adlr-autoresume', action='store_true',
-                       help='enable autoresume on adlr cluster.')
-    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
-                       help='intervals over which check for autoresume'
-                       'termination signal')
     return parser
@@ -408,9 +401,6 @@ def add_data_args_(parser):
     group = parser.add_argument_group('data', 'data configurations')
-    group.add_argument('--shuffle', action='store_true',
-                       help='Shuffle data. Shuffling is deterministic '
-                       'based on seed and current epoch.')
     group.add_argument('--data-loader', type=str, default=None,
                        choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                        help='Which data loader to use. Default varies by model.')
@@ -423,137 +413,10 @@ def add_data_args_(parser):
     group.add_argument('--test-data', nargs='*', default=None,
                        help='path(s) to the testing data.')
-    group.add_argument('--max-preds-per-seq', type=int, default=None,
-                       help='Maximum number of predictions to use per sequence.'
-                       'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
-                       'MUST BE SPECIFIED IF `--data-loader tfrecords`.')
     # arguments for binary data loader
-    parser.add_argument('--data-impl', type=str, default='infer',
-                        help='implementation of indexed datasets',
-                        choices=['lazy', 'cached', 'mmap', 'infer'])
-    parser.add_argument('--max-num-samples', type=int, default=None,
-                        help='Maximum number of samples to plan for, defaults to total iters * batch-size.')
-    parser.add_argument('--data-epochs', type=int, default=None,
-                        help='Number of epochs to plan for, defaults to using --max-num-samples')
     # arguments for numpy data loader
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                        help='the filename containing all the shards sizes for numpy data loader')
-    # arguments for raw/tfrecords data loader
-    group.add_argument('--delim', default=',',
-                       help='delimiter used to parse csv data files')
-    group.add_argument('--text-key', default='sentence',
-                       help='key to use to extract text from json/csv')
-    group.add_argument('--eval-text-key', default=None,
-                       help='key to use to extract text from '
-                       'json/csv evaluation datasets')
-    group.add_argument('--loose-json', action='store_true',
-                       help='Use loose json (one json-formatted string per '
-                       'newline), instead of tight json (data file is one '
-                       'json string)')
-    group.add_argument('--presplit-sentences', action='store_true',
-                       help='Dataset content consists of documents where '
-                       'each document consists of newline separated sentences')
-    group.add_argument('--tokenizer-model-type', type=str,
-                       default='bert-large-uncased',
-                       help="Model type to use for sentencepiece tokenization \
-                       (one of ['bpe', 'char', 'unigram', 'word']) or \
-                       bert vocab to use for BertWordPieceTokenizer (one of \
-                       ['bert-large-uncased', 'bert-large-cased', etc.])")
-    group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
-                       help='path used to save/load sentencepiece tokenization '
-                       'models')
-    group.add_argument('--tokenizer-type', type=str,
-                       default='BertWordPieceLowerCase',
-                       choices=['CharacterLevelTokenizer',
-                                'SentencePieceTokenizer',
-                                'BertWordPieceLowerCase',
-                                'GPT2BPETokenizer'],
-                       help='what type of tokenizer to use')
-    group.add_argument("--cache-dir", default=None, type=str,
-                       help="Where to store pre-trained BERT downloads")
     return parser
-def get_args_(extra_args_provider=None):
-    """Parse all the args."""
-    parser = argparse.ArgumentParser(description='Megatron-LM Arguments')
-    parser = add_network_size_args(parser)
-    parser = add_regularization_args(parser)
-    parser = add_training_args(parser)
-    parser = add_initialization_args(parser)
-    parser = add_learning_rate_args(parser)
-    parser = add_checkpointing_args(parser)
-    parser = add_mixed_precision_args(parser)
-    parser = add_distributed_args(parser)
-    parser = add_validation_args(parser)
-    parser = add_data_args(parser)
-    #parser.print_help()
-    #exit()
-    parser = add_model_config_args(parser)
-    parser = add_fp16_config_args(parser)
-    parser = add_training_args_(parser)
-    parser = add_evaluation_args(parser)
-    parser = add_text_generate_args(parser)
-    parser = add_data_args_(parser)
-    if extra_args_provider is not None:
-        parser = extra_args_provider(parser)
-    args = parser.parse_args()
-    # Checks.
-    if args.save is not None:
-        assert args.save_interval is not None, \
-            'expected \'--save-interval\' in the input arguments.'
-    if not args.train_data and not args.data_path:
-        print('WARNING: No training data specified')
-    args.cuda = torch.cuda.is_available()
-    args.rank = int(os.getenv('RANK', '0'))
-    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
-    if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
-        # We are using (OpenMPI) mpirun for launching distributed data parallel processes
-        local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
-        local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
-        # Possibly running with Slurm
-        num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
-        nodeid = int(os.getenv('SLURM_NODEID', '0'))
-        args.local_rank = local_rank
-        args.rank = nodeid*local_size + local_rank
-        args.world_size = num_nodes*local_size
-    args.model_parallel_size = min(args.model_parallel_size, args.world_size)
-    if args.rank == 0:
-        print('using world size: {} and model-parallel size: {} '.format(
-            args.world_size, args.model_parallel_size))
-    args.dynamic_loss_scale = False
-    if args.loss_scale is None:
-        args.dynamic_loss_scale = True
-        if args.rank == 0:
-            print(' > using dynamic loss scaling')
-    # The args fp32_* or fp16_* meant to be active when the
-    # args fp16 is set. So the default behaviour should all
-    # be false.
-    if not args.fp16:
-        args.fp32_embedding = False
-        args.fp32_tokentypes = False
-        args.fp32_layernorm = False
-    return args
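For readers of this diff: the refactored parse_args() now owns the full flow in one place (build the parser, apply an optional extra_args_provider hook, parse, then force-apply a defaults dict). A minimal standalone sketch of that contract, using illustrative names rather than the Megatron module itself:

```python
# Standalone sketch of the new parse_args() contract: an extra_args_provider
# hook adds caller-specific flags, and a `defaults` dict overrides attributes
# after parsing. Names here are illustrative, not part of the refactored module.
import argparse


def my_extra_args(parser):
    group = parser.add_argument_group(title='my extras')
    group.add_argument('--my-flag', type=int, default=0,
                       help='Example caller-specific argument.')
    return parser


def parse_args_sketch(extra_args_provider=None, defaults={}):
    parser = argparse.ArgumentParser(description='sketch')
    parser.add_argument('--model-parallel-size', type=int, default=1)
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    args = parser.parse_args([])  # empty argv so the sketch runs as-is
    # Input defaults win over whatever argparse produced.
    for key in defaults:
        setattr(args, key, defaults[key])
    return args


if __name__ == '__main__':
    args = parse_args_sketch(my_extra_args, defaults={'my_flag': 7})
    print(args.model_parallel_size, args.my_flag)  # -> 1 7
```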
@@ -18,18 +18,13 @@
 from abc import ABC
 from abc import abstractmethod
-from megatron.arguments import get_args
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
-def build_tokenizer():
+def build_tokenizer(args):
     """Initialize tokenizer."""
-    # Retrieve args.
-    args = get_args()
     if args.rank == 0:
-        print('building {} tokenizer ...'.format(args.tokenizer_type),
+        print('> building {} tokenizer ...'.format(args.tokenizer_type),
               flush=True)
     # Select and instantiate the tokenizer.
@@ -41,16 +36,16 @@ def build_tokenizer():
                          'implemented.'.format(args.tokenizer_type))
     # Add vocab size.
-    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
+    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
+                                                      args)
     return tokenizer
-def _vocab_size_with_padding(orig_vocab_size):
+def _vocab_size_with_padding(orig_vocab_size, args):
     """Pad vocab size so it is divisible by model parallel size and
     still having GPU friendly size."""
-    args = get_args()
     after = orig_vocab_size
     multiple = args.make_vocab_size_divisible_by * \
                args.model_parallel_size
...
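The tokenizer module now receives args explicitly instead of calling get_args(). The padding rule itself is unchanged: the vocabulary is grown to the next multiple of make-vocab-size-divisible-by times the model-parallel size. A small self-contained sketch of that arithmetic (the loop body is an assumption about the elided part of the function, consistent with the `multiple` computed above):

```python
# Sketch of the padding arithmetic in _vocab_size_with_padding(), with the
# relevant args passed as plain parameters. Values are illustrative.
def vocab_size_with_padding(orig_vocab_size, divisible_by=128,
                            model_parallel_size=2):
    multiple = divisible_by * model_parallel_size
    after = orig_vocab_size
    # Grow until the padded size is a clean multiple.
    while after % multiple != 0:
        after += 1
    return after


print(vocab_size_with_padding(30522))  # 30522 -> 30720 (multiple of 256)
```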
@@ -59,36 +59,38 @@ def get_timers():
     return _GLOBAL_TIMERS
-def set_global_variables(extra_args_provider=None):
+def set_global_variables(extra_args_provider=None, args_defaults={}):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
-    _parse_args(extra_args_provider=extra_args_provider)
-    _build_tokenizer()
-    _set_tensorboard_writer()
-    _set_adlr_autoresume()
+    args = _parse_args(extra_args_provider=extra_args_provider,
+                       defaults=args_defaults)
+    _build_tokenizer(args)
+    _set_tensorboard_writer(args)
+    _set_adlr_autoresume(args)
     _set_timers()
-def _parse_args(extra_args_provider=None):
+def _parse_args(extra_args_provider=None, defaults={}):
     """Parse entire arguments."""
     global _GLOBAL_ARGS
     _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
-    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider)
+    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
+                              defaults=defaults)
+    return _GLOBAL_ARGS
-def _build_tokenizer():
+def _build_tokenizer(args):
     """Initialize tokenizer."""
     global _GLOBAL_TOKENIZER
     _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
-    _GLOBAL_TOKENIZER = build_tokenizer()
+    _GLOBAL_TOKENIZER = build_tokenizer(args)
-def _set_tensorboard_writer():
+def _set_tensorboard_writer(args):
     """Set tensorboard writer."""
     global _GLOBAL_TENSORBOARD_WRITER
     _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                    'tensorboard writer')
-    args = get_args()
     if hasattr(args, 'tensorboard_dir') and \
        args.tensorboard_dir and args.rank == 0:
         try:
@@ -102,12 +104,11 @@ def _set_tensorboard_writer():
                   'no TensorBoard logs will be written.', flush=True)
-def _set_adlr_autoresume():
+def _set_adlr_autoresume(args):
     """Initialize ADLR autoresume."""
     global _GLOBAL_ADLR_AUTORESUME
     _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
-    args = get_args()
     if args.adlr_autoresume:
         if args.rank == 0:
             print('enabling autoresume ...', flush=True)
...
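set_global_variables() now parses once and threads the args object through every builder, instead of each builder calling get_args() itself. A standalone sketch of that singleton pattern, with stand-in builders rather than the real megatron.global_vars module:

```python
# Standalone sketch of the global_vars pattern: parse once, assert the
# singletons are not re-initialized, and pass args to each builder explicitly.
# The "tokenizer" here is a stand-in string, not a real tokenizer.
import argparse

_GLOBAL_ARGS = None
_GLOBAL_TOKENIZER = None


def _ensure_var_is_not_initialized(var, name):
    assert var is None, '{} is already initialized.'.format(name)


def _parse_args(defaults={}):
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS = argparse.Namespace(**defaults)
    return _GLOBAL_ARGS


def _build_tokenizer(args):
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = 'tokenizer-for-' + args.tokenizer_type
    return _GLOBAL_TOKENIZER


def set_global_variables(args_defaults={}):
    # Parse once, then hand the args object to every builder explicitly.
    args = _parse_args(defaults=args_defaults)
    _build_tokenizer(args)


set_global_variables({'tokenizer_type': 'BertWordPieceLowerCase'})
print(_GLOBAL_ARGS.tokenizer_type, _GLOBAL_TOKENIZER)
```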
@@ -24,15 +24,20 @@ import torch
 from megatron import mpu
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_args
+from .global_vars import get_tensorboard_writer
 from .global_vars import set_global_variables
-def initialize_megatron(extra_args_provider=None):
+def initialize_megatron(extra_args_provider=None, args_defaults={}):
     """Set global variables, initialize distributed, and
     set autoresume and random seeds."""
+    # Make sure cuda is available.
+    assert torch.cuda.is_available(), 'Megatron requires CUDA.'
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
-    set_global_variables(extra_args_provider=extra_args_provider)
+    set_global_variables(extra_args_provider=extra_args_provider,
+                         args_defaults=args_defaults)
     # Pytorch distributed.
     _initialize_distributed()
@@ -46,6 +51,9 @@ def initialize_megatron(extra_args_provider=None):
         print('> setting random seeds to {} ...'.format(args.seed))
     _set_random_seed(args.seed)
+    # Write arguments to tensorboard.
+    _write_args_to_tensorboard()
 def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
@@ -107,3 +115,12 @@ def _set_random_seed(seed):
         mpu.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
+def _write_args_to_tensorboard():
+    """Write arguments to tensorboard."""
+    args = get_args()
+    writer = get_tensorboard_writer()
+    if writer:
+        for arg in vars(args):
+            writer.add_text(arg, str(getattr(args, arg)))
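The add_text() calls that used to live inside _print_args() move into the new _write_args_to_tensorboard() helper, called after initialization. A self-contained sketch of what the helper does, using a stand-in writer so the snippet runs without torch or tensorboard installed:

```python
# Sketch of _write_args_to_tensorboard(): every parsed argument is logged as
# a text entry on the (optional) tensorboard writer. FakeWriter is a stand-in
# for torch.utils.tensorboard.SummaryWriter.
import argparse


class FakeWriter:
    def add_text(self, tag, text):
        print('tensorboard text: {} = {}'.format(tag, text))


def write_args_to_tensorboard(args, writer):
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)))


args = argparse.Namespace(num_layers=24, hidden_size=1024, fp16=True)
write_args_to_tensorboard(args, FakeWriter())
```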
@@ -22,7 +22,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
-from megatron.arguments import get_args
+from megatron.arguments import parse_args
 from megatron import mpu
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
@@ -72,7 +72,7 @@ def run(top_level_message, train_val_test_data_provider,
     """
     # Initalize and get arguments, timers, and Tensorboard writer.
-    args = get_args(extra_args_provider=extra_args_provider)
+    args = parse_args(extra_args_provider=extra_args_provider)
     timers, writer = initialize_megatron(top_level_message, args)
     # Data stuff.
...
@@ -196,5 +196,11 @@ def get_train_val_test_data(args):
 if __name__ == "__main__":
+    '''
+    from megatron.initialize import initialize_megatron
+    initialize_megatron(args_defaults={
+        'tokenizer_type': 'BertWordPieceLowerCase'})
+    exit()
+    '''
     run('Pretrain BERT model', get_train_val_test_data,
         model_provider, forward_step)
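The commented-out block previews how entry scripts are expected to adopt the new flow: call initialize_megatron() with per-model args_defaults instead of wiring up their own parser. A rough sketch of that shape; main() and its elided body are placeholders, not code from this commit:

```python
# Hypothetical future entry-point shape suggested by the commented-out block.
# Only initialize_megatron() and its args_defaults key come from this diff;
# everything inside main() after it is a placeholder.
from megatron.initialize import initialize_megatron


def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type':
                                       'BertWordPieceLowerCase'})
    # ... build model, data iterators, and run the training loop ...


if __name__ == '__main__':
    main()
```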