Commit 87bd026a authored by Raul Puri

Update initialize.py
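Summary of the diff below: device_count is now initialized from torch.cuda.device_count() instead of 0, and the cached value is reused for the local-rank check and device-id selection rather than calling torch.cuda.device_count() repeatedly.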

parent 5448ca25
@@ -59,7 +59,7 @@ def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
     args = get_args()
 
-    device_count = 0
+    device_count = torch.cuda.device_count()
     if torch.distributed.is_initialized():
 
         if args.rank == 0:
@@ -69,7 +69,7 @@ def _initialize_distributed():
         args.world_size = torch.distributed.get_world_size()
         if device_count > 0:
             device = torch.cuda.current_device()
-            local_rank = args.rank % torch.cuda.device_count()
+            local_rank = args.rank % device_count
             assert local_rank == device, \
                 'expected local-rank to be the same as rank % device-count.'
@@ -79,7 +79,7 @@ def _initialize_distributed():
             print('> initializing torch distributed ...', flush=True)
         # Manually set the device ids.
         if device_count > 0:
-            device = args.rank % torch.cuda.device_count()
+            device = args.rank % device_count
             if args.local_rank is not None:
                 assert args.local_rank == device, \
                     'expected local-rank to be the same as rank % device-count.'
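For readers outside the repo, a minimal, self-contained sketch of the initialization pattern the diff converges on: query torch.cuda.device_count() once, then reuse the cached value both when validating an already-initialized process group and when picking the device id before creating one. The function name init_distributed_sketch and its parameters are hypothetical; only the rank/device-count arithmetic and the assertion message come from the diff.

import torch

def init_distributed_sketch(rank, world_size, local_rank=None,
                            backend='nccl', init_method='env://'):
    # Query the GPU count once and reuse the cached value throughout,
    # mirroring the commit above.
    device_count = torch.cuda.device_count()

    if torch.distributed.is_initialized():
        # Process group already exists: only validate device placement.
        rank = torch.distributed.get_rank()
        if device_count > 0:
            device = torch.cuda.current_device()
            assert rank % device_count == device, \
                'expected local-rank to be the same as rank % device-count.'
        return

    # Manually set the device id before creating the process group.
    if device_count > 0:
        device = rank % device_count
        if local_rank is not None:
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'
        torch.cuda.set_device(device)

    torch.distributed.init_process_group(
        backend=backend, init_method=init_method,
        world_size=world_size, rank=rank)

Reusing one cached value keeps the local-rank assertion and the device-id selection consistent with each other and avoids redundant CUDA queries.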