Unverified commit 95acffb9, authored by Wenwei Zhang, committed by GitHub
Browse files

support mpi launcher (#726)

parent 508a322f
@@ -35,7 +35,11 @@ def _init_dist_pytorch(backend, **kwargs):
def _init_dist_mpi(backend, **kwargs):
    """Initialize ``torch.distributed`` for a process launched by Open MPI.

    Binds the current process to one CUDA device and then creates the
    default process group.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``dist.init_process_group``.

    Raises:
        KeyError: If neither ``OMPI_COMM_WORLD_LOCAL_RANK`` nor
            ``OMPI_COMM_WORLD_RANK`` is present in the environment.
        RuntimeError: If no CUDA device is visible to this process.
    """
    # Prefer the node-local rank: deriving the device from the global rank
    # (rank % num_gpus) selects the wrong GPU whenever ranks are not laid
    # out in per-node blocks of exactly ``num_gpus`` processes.
    local_rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')
    if local_rank is None:
        # Fall back to the global rank so launch setups that only export
        # OMPI_COMM_WORLD_RANK keep working as before.
        local_rank = os.environ['OMPI_COMM_WORLD_RANK']
    local_rank = int(local_rank)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Without this guard the modulo below fails with an opaque
        # ZeroDivisionError.
        raise RuntimeError(
            'No CUDA device is available for the MPI launcher')

    # The modulo keeps the index valid when a node runs more processes
    # than it has GPUs.
    torch.cuda.set_device(local_rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None): def _init_dist_slurm(backend, port=None):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment