Commit 5448ca25 authored by Raul Puri

Added logic for initialize_megatron to work with CPU-only jobs. This is necessary for several evaluation and processing scripts in downstream repos.
parent 1363e503
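The change below follows one pattern throughout: query the visible GPU count once and guard every CUDA-specific call behind it, so torch.distributed can still come up on CPU-only nodes. A minimal, self-contained sketch of that guard follows; the function name, the default init_method, and the gloo fallback backend are illustrative assumptions, not taken from this commit.

import torch
import torch.distributed

def init_distributed_cpu_friendly(rank, world_size,
                                   init_method='tcp://localhost:6000'):
    # torch.cuda.device_count() returns 0 on CPU-only nodes, so every
    # CUDA-specific call below is skipped there.
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # Bind this rank to a local GPU only when one is actually present.
        torch.cuda.set_device(rank % device_count)
    # nccl needs GPUs; gloo works on CPU (the backend choice here is an
    # assumption -- the commit itself keeps args.distributed_backend).
    backend = 'nccl' if device_count > 0 else 'gloo'
    torch.distributed.init_process_group(
        backend=backend, world_size=world_size, rank=rank,
        init_method=init_method)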
@@ -59,6 +59,7 @@ def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
     args = get_args()
+    device_count = 0
     if torch.distributed.is_initialized():
         if args.rank == 0:
@@ -66,6 +67,7 @@ def _initialize_distributed():
                   'skipping initialization ...', flush=True)
         args.rank = torch.distributed.get_rank()
         args.world_size = torch.distributed.get_world_size()
+        if device_count > 0:
             device = torch.cuda.current_device()
             local_rank = args.rank % torch.cuda.device_count()
             assert local_rank == device, \
@@ -76,6 +78,7 @@ def _initialize_distributed():
         if args.rank == 0:
             print('> initializing torch distributed ...', flush=True)
         # Manually set the device ids.
+        if device_count > 0:
             device = args.rank % torch.cuda.device_count()
             if args.local_rank is not None:
                 assert args.local_rank == device, \
@@ -94,6 +97,7 @@ def _initialize_distributed():
             init_method=init_method)
     # Set the model-parallel / data-parallel communicators.
+    if device_count > 0:
         mpu.initialize_model_parallel(args.model_parallel_size)
@@ -112,6 +116,7 @@ def _set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
+        if torch.cuda.device_count() > 0:
             mpu.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
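The same guard closes out _set_random_seed: the model-parallel CUDA RNG is seeded only when a GPU is visible. A standalone sketch of that behaviour, runnable on a CPU-only machine, with torch.cuda.manual_seed_all standing in for mpu.model_parallel_cuda_manual_seed (an illustrative substitution, not the mpu call itself):

import random
import numpy as np
import torch

def set_random_seed_cpu_friendly(seed):
    # Mirrors the guarded seeding above: CUDA seeding runs only when at
    # least one GPU is visible, so CPU-only jobs no longer fail here.
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            # Stand-in for mpu.model_parallel_cuda_manual_seed(seed).
            torch.cuda.manual_seed_all(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))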