Commit 5c04ceb3 authored by Boris Fomitchev

Implementing lazy parallel initialization


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent cf9bdadc
@@ -322,7 +322,10 @@ def _add_distributed_args(parser):
                        'to use.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
+    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+                       help='If set to True, initialize_megatron() skips DDP '
+                       'initialization and returns a function to complete it '
+                       'instead. This is for an external DDP manager.')
     return parser
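
One caveat worth noting about the new flag: `argparse` converts the supplied string with `bool()`, so any non-empty value, including `"False"`, parses as `True`. A minimal sketch of that behavior with a throwaway parser (only the flag name is taken from the hunk above):

```python
import argparse

# Throwaway parser that mirrors only the new flag's declaration.
parser = argparse.ArgumentParser()
parser.add_argument('--lazy-mpu-init', type=bool, required=False)

# bool() on a non-empty string is True, so '--lazy-mpu-init False' still enables it.
print(parser.parse_args(['--lazy-mpu-init', 'False']).lazy_mpu_init)  # True
print(parser.parse_args(['--lazy-mpu-init', '']).lazy_mpu_init)       # False
print(parser.parse_args([]).lazy_mpu_init)                            # None (flag omitted)
```

In practice, a framework embedding Megatron would more likely set this value programmatically (for instance through `args_defaults`) than on the command line.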
@@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
-from megatron.global_vars import set_global_variables
+from .global_vars import set_global_variables
+from .mpu import set_model_parallel_rank, set_model_parallel_world_size


 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -34,7 +34,11 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     set autoresume and random seeds.
     `allow_no_cuda` should not be set unless using megatron for cpu only
     data processing. In general this arg should not be set unless you know
-    what you are doing."""
+    what you are doing.
+
+    Returns a function to finalize distributed env initialization
+    (optionally, only for args.distributed_backend == "external_ddp").
+    """
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), 'Megatron requires CUDA.'
@@ -45,21 +49,37 @@
                          args_defaults=args_defaults,
                          ignore_unknown_args=ignore_unknown_args)

-    # Pytorch distributed.
-    _initialize_distributed()
-
-    # Autoresume.
-    _init_autoresume()
-
-    # Random seeds for reproducibility.
-    args = get_args()
-    if args.rank == 0:
-        print('> setting random seeds to {} ...'.format(args.seed))
-    _set_random_seed(args.seed)
-
-    # Write arguments to tensorboard.
-    _write_args_to_tensorboard()
+    # torch.distributed initialization.
+    def ddp_init():
+        args = get_args()
+        # Pytorch distributed.
+        _initialize_distributed()
+
+        # Random seeds for reproducibility.
+        if args.rank == 0:
+            print('> setting random seeds to {} ...'.format(args.seed))
+        _set_random_seed(args.seed)
+
+    args = get_args()
+    if 'lazy_mpu_init' in args:
+        # Delayed initialization of DDP-related stuff:
+        # we only set the basic DDP globals here,
+        set_model_parallel_world_size(args.model_parallel_size)
+        set_model_parallel_rank(args.rank)
+        # and return a function for the external DDP manager to call
+        # once it has DDP initialized.
+        return ddp_init
+    else:
+        # Megatron's own DDP. Do initialization right away.
+        ddp_init()
+
+        # Autoresume.
+        _init_autoresume()
+
+        # Write arguments to tensorboard.
+        _write_args_to_tensorboard()
+
+        # No continuation function.
+        return None


 def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
@@ -15,6 +15,7 @@
 from .distributed import *
 from .bert_model import BertModel
-from megatron.model.realm_model import ICTBertModel
+from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
 from .language_model import get_language_model
@@ -24,7 +24,7 @@ from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
+from megatron.model.utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
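
For context, the two helpers imported here return weight-initializer closures; the defaults added to `get_language_model` below are built from them. The bodies in this sketch only illustrate the pattern (the real implementations live in `megatron/model/utils.py`), with the second initializer's standard deviation shrunk as the network gets deeper:

```python
import math
import torch


def init_method_normal(std):
    """Return a closure that fills a tensor from N(0, std)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_


def scaled_init_method_normal(std, num_layers):
    """Same idea, but scaled down with depth so residual-branch
    outputs do not grow as layers are stacked."""
    scaled_std = std / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=scaled_std)
    return init_
```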
@@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method, scaled_init_method):
+                       init_method=None, scaled_init_method=None):
     """Build language model and return along with the key to save."""

     args = get_args()
@@ -55,6 +55,12 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     elif args.onnx_safe:
         gelu = erf_gelu

+    if init_method is None:
+        init_method = init_method_normal(args.init_method_std)
+
+    if scaled_init_method is None:
+        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+
     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
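
With these defaults in place, callers can omit the init methods and `get_language_model` derives them from the global args; passing them explicitly keeps working as before. A hedged usage sketch (it assumes `initialize_megatron()` has already populated the global args, and the attention-mask function is a placeholder, not part of this commit):

```python
from megatron import get_args
from megatron.model.language_model import get_language_model
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def attention_mask_func(attention_scores, attention_mask):
    # Placeholder: mask the scores the way the calling model expects.
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


# New style: rely on defaults derived from args.init_method_std and args.num_layers.
language_model, lm_key = get_language_model(
    attention_mask_func=attention_mask_func,
    num_tokentypes=0,
    add_pooler=False)

# Old style is unchanged: pass the initializers explicitly.
args = get_args()
language_model, lm_key = get_language_model(
    attention_mask_func=attention_mask_func,
    num_tokentypes=0,
    add_pooler=False,
    init_method=init_method_normal(args.init_method_std),
    scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers))
```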
@@ -27,9 +27,9 @@ from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
 from .initialize import get_data_parallel_world_size
 from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
+from .initialize import get_model_parallel_rank, set_model_parallel_rank
 from .initialize import get_model_parallel_src_rank
-from .initialize import get_model_parallel_world_size
+from .initialize import get_model_parallel_world_size, set_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
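
Exporting the two setters lets code that manages distributed setup itself record the model-parallel layout before any process groups exist, which is what the lazy branch of `initialize_megatron()` does. A minimal sketch of that lower-level use (the values are placeholders):

```python
from megatron import mpu

# Record the intended model-parallel layout up front, e.g. from an
# external DDP manager that tracks ranks itself.
mpu.set_model_parallel_world_size(1)
mpu.set_model_parallel_rank(0)

# Code that builds model-parallel layers before torch.distributed is up
# can then query the recorded layout through the usual getters.
world_size = mpu.get_model_parallel_world_size()
rank = mpu.get_model_parallel_rank()
```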
@@ -93,16 +93,13 @@ def model_parallel_is_initialized():
         return False
     return True


 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, \
         'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP


-def set_model_parallel_group(group):
-    """Set model parallel group."""
-    global _MODEL_PARALLEL_GROUP
-    _MODEL_PARALLEL_GROUP = group
-
-
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
@@ -110,10 +107,6 @@ def get_data_parallel_group():
         'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP


-def set_data_parallel_group(group):
-    """Set data parallel group."""
-    global _DATA_PARALLEL_GROUP
-    _DATA_PARALLEL_GROUP = group
-
-
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""