Unverified Commit 7e3509bb authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

if not init torch distributed in mpi mode, make sure they match (#112)

parent fafc62fe
......@@ -119,11 +119,11 @@ class DeepSpeedLight(Module):
self.gradient_average = True
self.warn_unscaled_loss = True
self._mpi_check(args)
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
self._mpi_check(args, dist_init_required)
self.dist_backend = "nccl"
if dist_init_required:
if not dist.is_initialized():
......@@ -186,7 +186,7 @@ class DeepSpeedLight(Module):
if self.dump_state():
print_configuration(self, 'DeepSpeedLight')
def _mpi_check(self, args):
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
import subprocess
......@@ -220,6 +220,10 @@ class DeepSpeedLight(Module):
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(world_size, dist.get_world_size())
def tensorboard_enabled(self):
return self._config.tensorboard_enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment