"vscode:/vscode.git/clone" did not exist on "0332624ee069aa3dea97804acd8575c21a0cde2a"
Unverified Commit 7e3509bb authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

if not init torch distributed in mpi mode, make sure they match (#112)

parent fafc62fe
......@@ -119,11 +119,11 @@ class DeepSpeedLight(Module):
self.gradient_average = True
self.warn_unscaled_loss = True
self._mpi_check(args)
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
self._mpi_check(args, dist_init_required)
self.dist_backend = "nccl"
if dist_init_required:
if not dist.is_initialized():
......@@ -186,7 +186,7 @@ class DeepSpeedLight(Module):
if self.dump_state():
print_configuration(self, 'DeepSpeedLight')
def _mpi_check(self, args):
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
import subprocess
......@@ -220,6 +220,10 @@ class DeepSpeedLight(Module):
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(world_size, dist.get_world_size())
def tensorboard_enabled(self):
return self._config.tensorboard_enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment