Unverified Commit 0e831e23 authored by Ammar Ahmad Awan, committed by GitHub

Simplify dist init and only init if needed. (#553)


Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 6e65c2cc
@@ -121,25 +121,31 @@ class DeepSpeedEngine(Module):
         self.loaded_checkpoint_dp_world_size = None
         self.enable_backward_allreduce = True
         self.progressive_layer_drop = None
+        self.dist_backend = "nccl"

         if dist_init_required is None:
             dist_init_required = not dist.is_initialized()

-        if self._in_aml():
-            self._set_environment_variables_for_nccl_backend(args)
-        else:
-            self._mpi_check(args, dist_init_required)
+        if dist_init_required is False:
+            assert (dist.is_initialized() == True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"

-        self.dist_backend = "nccl"
-        if dist_init_required:
-            if not dist.is_initialized():
-                logger.info("Initializing torch distributed with backend: {}".format(
-                    self.dist_backend))
-                dist.init_process_group(backend=self.dist_backend)
-            else:
-                logger.warning(
-                    "Was given dist_init_required=True but detected that torch"
-                    "distributed was already initialized, cannot initialize twice.")
+        # DeepSpeed will initialize torch distributed only if the user has not already initialized it.
+        if dist_init_required and not dist.is_initialized():
+            # discover using mpi4py if user specifies the flag
+            if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
+                # if in Azure ML environment and user specified this flag, notify the user to remove the flag.
+                if self._in_aml():
+                    logger.warning(
+                        "Please remove the --deepspeed_mpi flag if running on AzureML.")
+                self._mpi_check(args, dist_init_required)
+            else:
+                # detect if we are in Azure ML environment
+                if self._in_aml():
+                    self._set_environment_variables_for_nccl_backend(args)
+
+            logger.info("Initializing torch distributed with backend: {}".format(
+                self.dist_backend))
+            dist.init_process_group(backend=self.dist_backend)

         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
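After this hunk, deepspeed.initialize() creates the torch.distributed process group only when one does not already exist, and fails fast if the caller passes dist_init_required=False without initializing it first. A minimal sketch of the two call patterns, assuming a toy model and a DeepSpeed JSON config supplied on the command line (names such as ds_config.json are placeholders, not part of this commit):

    import argparse
    import torch
    import torch.distributed as dist
    import deepspeed

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed, --deepspeed_config, ...
    args = parser.parse_args()                       # e.g. --deepspeed_config ds_config.json

    model = torch.nn.Linear(10, 10)                  # toy model, stands in for a real network

    # Pattern 1: DeepSpeed owns the process group. With the default
    # dist_init_required=None it initializes torch.distributed only if needed.
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())

    # Pattern 2: the application owns torch.distributed (RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT must already be set in the environment):
    #
    #   dist.init_process_group(backend='nccl')
    #   engine, _, _, _ = deepspeed.initialize(args=args,
    #                                          model=model,
    #                                          model_parameters=model.parameters(),
    #                                          dist_init_required=False)
    #
    # Passing dist_init_required=False without calling init_process_group() first
    # now trips the assert added in this hunk instead of failing later.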
@@ -203,7 +209,7 @@ class DeepSpeedEngine(Module):
         self.unflatten = util_ops.unflatten

     def _in_aml(self):
-        # read and environment variable to detect if we are using an Azure ML environment
+        # read AzureML environment variable to detect if we are using an Azure ML environment
         if 'AZUREML_EXPERIMENT_ID' in os.environ:
             return True
         else:
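The only signal _in_aml() uses is the AZUREML_EXPERIMENT_ID environment variable, so the AzureML branch can be exercised on any machine by setting that variable. A small standalone sketch of the same check (the in_aml helper name here is illustrative, not part of DeepSpeed):

    import os

    def in_aml() -> bool:
        # Mirrors DeepSpeedEngine._in_aml(): AzureML jobs export AZUREML_EXPERIMENT_ID.
        return 'AZUREML_EXPERIMENT_ID' in os.environ

    os.environ.pop('AZUREML_EXPERIMENT_ID', None)
    print(in_aml())                                           # False outside AzureML
    os.environ['AZUREML_EXPERIMENT_ID'] = 'dummy-experiment'  # simulate an AzureML job
    print(in_aml())                                           # True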
@@ -246,43 +252,42 @@ class DeepSpeedEngine(Module):
                     os.environ['MASTER_PORT']))

     def _mpi_check(self, args, dist_init_required):
-        if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
-            from mpi4py import MPI
-            import subprocess
-            comm = MPI.COMM_WORLD
-            rank = comm.Get_rank()
-            world_size = comm.Get_size()
+        from mpi4py import MPI
+        import subprocess
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        world_size = comm.Get_size()

-            master_addr = None
-            if rank == 0:
-                hostname_cmd = ["hostname -I"]
-                result = subprocess.check_output(hostname_cmd, shell=True)
-                master_addr = result.decode('utf-8').split()[0]
-            master_addr = comm.bcast(master_addr, root=0)
+        master_addr = None
+        if rank == 0:
+            hostname_cmd = ["hostname -I"]
+            result = subprocess.check_output(hostname_cmd, shell=True)
+            master_addr = result.decode('utf-8').split()[0]
+        master_addr = comm.bcast(master_addr, root=0)

-            # Determine local rank by assuming hostnames are unique
-            proc_name = MPI.Get_processor_name()
-            all_procs = comm.allgather(proc_name)
-            local_rank = sum([i == proc_name for i in all_procs[:rank]])
+        # Determine local rank by assuming hostnames are unique
+        proc_name = MPI.Get_processor_name()
+        all_procs = comm.allgather(proc_name)
+        local_rank = sum([i == proc_name for i in all_procs[:rank]])

-            os.environ['RANK'] = str(rank)
-            os.environ['WORLD_SIZE'] = str(world_size)
-            args.local_rank = local_rank
-            os.environ['MASTER_ADDR'] = master_addr
-            os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT
+        os.environ['RANK'] = str(rank)
+        os.environ['WORLD_SIZE'] = str(world_size)
+        args.local_rank = local_rank
+        os.environ['MASTER_ADDR'] = master_addr
+        os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT

-            logger.info(
-                "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
-                .format(os.environ['RANK'],
-                        args.local_rank,
-                        os.environ['WORLD_SIZE'],
-                        os.environ['MASTER_ADDR'],
-                        os.environ['MASTER_PORT']))
+        logger.info(
+            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
+            .format(os.environ['RANK'],
+                    args.local_rank,
+                    os.environ['WORLD_SIZE'],
+                    os.environ['MASTER_ADDR'],
+                    os.environ['MASTER_PORT']))

-            if not dist_init_required and dist.is_initialized():
-                assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
-                assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
-                    world_size, dist.get_world_size())
+        if not dist_init_required and dist.is_initialized():
+            assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
+            assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
+                world_size, dist.get_world_size())

     def pld_enabled(self):
         return self._config.pld_enabled
...
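The least obvious step in _mpi_check is deriving local_rank: each rank allgathers its hostname and counts how many lower ranks report the same hostname. A standalone sketch of that calculation with a made-up allgather result (no MPI required; the hostnames are invented for illustration):

    # Hypothetical allgather result for 6 ranks spread over 2 hosts.
    all_procs = ['node0', 'node0', 'node0', 'node1', 'node1', 'node1']

    for rank, proc_name in enumerate(all_procs):
        # Same expression as in _mpi_check: count earlier ranks on the same host.
        local_rank = sum(i == proc_name for i in all_procs[:rank])
        print(rank, proc_name, local_rank)

    # Expected output:
    # 0 node0 0
    # 1 node0 1
    # 2 node0 2
    # 3 node1 0
    # 4 node1 1
    # 5 node1 2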