"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8de78001df95a641bf6ef942bee9553921d44490"
Unverified Commit 7e3509bb authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

if not init torch distributed in mpi mode, make sure they match (#112)

parent fafc62fe
...@@ -119,11 +119,11 @@ class DeepSpeedLight(Module): ...@@ -119,11 +119,11 @@ class DeepSpeedLight(Module):
self.gradient_average = True self.gradient_average = True
self.warn_unscaled_loss = True self.warn_unscaled_loss = True
self._mpi_check(args)
if dist_init_required is None: if dist_init_required is None:
dist_init_required = not dist.is_initialized() dist_init_required = not dist.is_initialized()
self._mpi_check(args, dist_init_required)
self.dist_backend = "nccl" self.dist_backend = "nccl"
if dist_init_required: if dist_init_required:
if not dist.is_initialized(): if not dist.is_initialized():
...@@ -186,7 +186,7 @@ class DeepSpeedLight(Module): ...@@ -186,7 +186,7 @@ class DeepSpeedLight(Module):
if self.dump_state(): if self.dump_state():
print_configuration(self, 'DeepSpeedLight') print_configuration(self, 'DeepSpeedLight')
def _mpi_check(self, args): def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI from mpi4py import MPI
import subprocess import subprocess
...@@ -220,6 +220,10 @@ class DeepSpeedLight(Module): ...@@ -220,6 +220,10 @@ class DeepSpeedLight(Module):
os.environ['MASTER_ADDR'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT'])) os.environ['MASTER_PORT']))
if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(world_size, dist.get_world_size())
def tensorboard_enabled(self): def tensorboard_enabled(self):
return self._config.tensorboard_enabled return self._config.tensorboard_enabled
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment