"...text-generation-inference.git" did not exist on "8bd0adb1356d5c6e0238abe64a6cf94b50140db4"
Commit c632d205 authored by Boris Fomitchev

Addressing code review comments


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent be842037
@@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
-from .global_vars import set_global_variables
-from .mpu import set_model_parallel_rank, set_model_parallel_world_size
+from megatron.global_vars import set_global_variables
+from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size
 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -50,7 +50,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
                       ignore_unknown_args=ignore_unknown_args)
     # torch.distributed initialization
-    def ddp_init():
+    def finish_mpu_init():
         args = get_args()
         # Pytorch distributed.
         _initialize_distributed()
@@ -61,16 +61,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         _set_random_seed(args.seed)
     args = get_args()
-    if 'lazy_mpu_init' in args:
+    if args.lazy_mpu_init:
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_model_parallel_world_size(args.model_parallel_size)
         # and return function for external DDP manager to call when it has DDP initialized
         set_model_parallel_rank(args.rank)
-        return ddp_init
+        return finish_mpu_init
     else:
-        # Megatron's own DDP. Do initialization right away
-        ddp_init()
+        # Megatron's MPU is the master. Complete initialization right away.
+        finish_mpu_init()
         # Autoresume.
         _init_autoresume()
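For context, a minimal usage sketch of the lazy_mpu_init path this commit touches: when the flag is set, initialize_megatron only sets the basic model-parallel globals and returns finish_mpu_init for an external DDP manager to call once distributed setup is done. The import path and the args_defaults mechanism for setting the flag are assumptions for illustration, not part of this commit.

# Hypothetical sketch of the lazy_mpu_init flow (assumed import path and
# flag-passing mechanism; not part of this commit).
from megatron.initialize import initialize_megatron

# Ask Megatron to defer distributed setup: it sets the model-parallel
# world size and rank globals, then returns finish_mpu_init.
finish_mpu_init = initialize_megatron(args_defaults={'lazy_mpu_init': True})

# ... the external framework initializes torch.distributed / DDP here ...

# Once DDP is ready, the caller completes Megatron's initialization.
finish_mpu_init()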