Commit c632d205 authored by Boris Fomitchev

Addressing code review comments


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent be842037
@@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
-from .global_vars import set_global_variables
-from .mpu import set_model_parallel_rank, set_model_parallel_world_size
+from megatron.global_vars import set_global_variables
+from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size
 
 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -50,7 +50,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=ignore_unknown_args)
 
     # torch.distributed initialization
-    def ddp_init():
+    def finish_mpu_init():
         args = get_args()
         # Pytorch distributed.
         _initialize_distributed()
@@ -61,16 +61,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         _set_random_seed(args.seed)
 
     args = get_args()
-    if 'lazy_mpu_init' in args:
+    if args.lazy_mpu_init:
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_model_parallel_world_size(args.model_parallel_size)
         # and return function for external DDP manager to call when it has DDP initialized
         set_model_parallel_rank(args.rank)
-        return ddp_init
+        return finish_mpu_init
     else:
-        # Megatron's own DDP. Do initialization right away
-        ddp_init()
+        # Megatron's MPU is the master. Complete initialization right away.
+        finish_mpu_init()
 
     # Autoresume.
     _init_autoresume()
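
For context, a minimal usage sketch of the lazy path this commit renames. It assumes lazy_mpu_init is supplied through args_defaults and that initialize_megatron is imported from megatron/initialize.py; the external-framework steps are only illustrative, not part of this change.

from megatron.initialize import initialize_megatron

# With lazy_mpu_init set, initialize_megatron only records the model-parallel
# world size and rank, then returns finish_mpu_init instead of calling it.
# (Passing the flag via args_defaults is an assumption about how it is supplied.)
finish_mpu_init = initialize_megatron(args_defaults={'lazy_mpu_init': True})

# ... the external DDP manager sets up torch.distributed / its own DDP here ...

# Once DDP is ready, hand control back so Megatron can complete MPU initialization.
if finish_mpu_init is not None:
    finish_mpu_init()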