Commit 97e07d64 authored by Mohammad's avatar Mohammad
Browse files

Merge branch 'master' into remove_local_ddp_bcast

parents 41c1af0e eb0a8bf0
...@@ -19,7 +19,8 @@ import argparse ...@@ -19,7 +19,8 @@ import argparse
import os import os
def parse_args(extra_args_provider=None, defaults={}): def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments.""" """Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments') parser = argparse.ArgumentParser(description='Megatron-LM Arguments')
...@@ -41,7 +42,10 @@ def parse_args(extra_args_provider=None, defaults={}): ...@@ -41,7 +42,10 @@ def parse_args(extra_args_provider=None, defaults={}):
parser = extra_args_provider(parser) parser = extra_args_provider(parser)
# Parse. # Parse.
args = parser.parse_args() if ignore_unknown_args:
args, _ = parser.parse_known_args()
else:
args = parser.parse_args()
# Set input defaults. # Set input defaults.
for key in defaults: for key in defaults:
......
...@@ -61,22 +61,26 @@ def get_timers(): ...@@ -61,22 +61,26 @@ def get_timers():
return _GLOBAL_TIMERS return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}): def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider, args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults) defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_ = _build_tokenizer(args) _ = _build_tokenizer(args)
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
_set_timers() _set_timers()
def _parse_args(extra_args_provider=None, defaults={}): def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments.""" """Parse entire arguments."""
global _GLOBAL_ARGS global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider, _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults) defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS return _GLOBAL_ARGS
......
...@@ -28,7 +28,8 @@ from megatron import mpu ...@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}): def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set global variables, initialize distributed, and """Set global variables, initialize distributed, and
set autoresume and random seeds.""" set autoresume and random seeds."""
# Make sure cuda is available. # Make sure cuda is available.
...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}): ...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
# Parse args, build tokenizer, and set adlr-autoresume, # Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers. # tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider, set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults) args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# Pytorch distributed. # Pytorch distributed.
_initialize_distributed() _initialize_distributed()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment