Commit 7890681a authored by Mohammad Shoeybi

Merge branch 'staging_distributed_cpu_fix' into 'staging'

Added logic for initialize_megatron to work with CPU-only jobs. This is...

See merge request ADLR/megatron-lm!46
parents 1363e503 87bd026a
@@ -59,6 +59,7 @@ def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
     args = get_args()
 
+    device_count = torch.cuda.device_count()
     if torch.distributed.is_initialized():
 
         if args.rank == 0:
@@ -66,23 +67,25 @@ def _initialize_distributed():
                   'skipping initialization ...', flush=True)
         args.rank = torch.distributed.get_rank()
         args.world_size = torch.distributed.get_world_size()
-        device = torch.cuda.current_device()
-        local_rank = args.rank % torch.cuda.device_count()
-        assert local_rank == device, \
-            'expected local-rank to be the same as rank % device-count.'
+        if device_count > 0:
+            device = torch.cuda.current_device()
+            local_rank = args.rank % device_count
+            assert local_rank == device, \
+                'expected local-rank to be the same as rank % device-count.'
 
     else:
         if args.rank == 0:
             print('> initializing torch distributed ...', flush=True)
         # Manually set the device ids.
-        device = args.rank % torch.cuda.device_count()
-        if args.local_rank is not None:
-            assert args.local_rank == device, \
-                'expected local-rank to be the same as rank % device-count.'
-        else:
-            args.local_rank = device
-        torch.cuda.set_device(device)
+        if device_count > 0:
+            device = args.rank % device_count
+            if args.local_rank is not None:
+                assert args.local_rank == device, \
+                    'expected local-rank to be the same as rank % device-count.'
+            else:
+                args.local_rank = device
+            torch.cuda.set_device(device)
 
     # Call the init process
     init_method = 'tcp://'
     master_ip = os.getenv('MASTER_ADDR', 'localhost')
@@ -94,7 +97,8 @@ def _initialize_distributed():
         init_method=init_method)
 
     # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
+    if device_count > 0:
+        mpu.initialize_model_parallel(args.model_parallel_size)
 
 
 def _init_autoresume():
@@ -112,7 +116,8 @@ def _set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            mpu.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
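For readers who want to exercise the same CPU-only path outside Megatron-LM, below is a minimal standalone sketch of the device-count guard applied to a single-process torch.distributed setup. The function name init_distributed_cpu_friendly, the gloo fallback, and the world_size=1 defaults are illustrative assumptions, not part of this merge request.

import os

import torch


def init_distributed_cpu_friendly(rank=0, world_size=1, backend=None):
    """Initialize torch.distributed, skipping CUDA-specific steps on CPU-only hosts.

    A sketch mirroring the guard added in this merge request; not Megatron-LM code.
    """
    device_count = torch.cuda.device_count()
    # Assumption: fall back to gloo when no GPUs are visible (nccl requires CUDA).
    if backend is None:
        backend = 'nccl' if device_count > 0 else 'gloo'
    # Build the same tcp:// init method the patched code assembles
    # from MASTER_ADDR / MASTER_PORT.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://' + master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend, world_size=world_size, rank=rank,
        init_method=init_method)
    # The merge request's guard: only touch CUDA state when devices exist.
    if device_count > 0:
        torch.cuda.set_device(rank % device_count)


if __name__ == '__main__':
    init_distributed_cpu_friendly()
    print('> initialized with backend:', torch.distributed.get_backend())

On a GPU-less machine this initializes with the gloo backend and never calls into torch.cuda, which is the failure mode the guarded code above avoids.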