Unverified Commit 8b129030 authored by Zachary Mueller, committed by GitHub

Bring back PartialState DeepSpeed (#22921)

* Bring back deepspeed integration

* Branchname

* Self-scheduled

* newline

* Use deepspeed env var

* Remove comment

* Del env var after partialstate
parent 4331923b
@@ -1544,39 +1544,26 @@ class TrainingArguments:
             self._n_gpu = 1
             torch.cuda.set_device(device)
         elif self.deepspeed:
-            # deepspeed inits torch.distributed internally
-            from .deepspeed import is_deepspeed_available
-
-            if not is_deepspeed_available():
-                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
-            import deepspeed
-
-            deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout))
-
-            # workaround for setups like notebooks where the launcher can't be used,
-            # but deepspeed requires a dist env.
-            # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
-            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
-
-            device = torch.device("cuda", self.local_rank)
+            # Need to do similar for Accelerator init
+            os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+            del os.environ["ACCELERATE_USE_DEEPSPEED"]
             self._n_gpu = 1
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
             self._n_gpu = 1
-        if not is_sagemaker_mp_enabled() and not self.deepspeed:
+        if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
-            and hasattr(self, "distributed_state")
             and self.distributed_state.distributed_type == DistributedType.NO
         ):
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
             )
-        if not self.deepspeed:
-            if is_torch_tpu_available():
-                device = self.distributed_state.device
-                self._n_gpu = 0
+        if is_torch_tpu_available():
+            device = self.distributed_state.device
+            self._n_gpu = 0
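
The new `elif self.deepspeed:` branch above swaps the direct `deepspeed.init_distributed` call for Accelerate's `PartialState`, toggling `ACCELERATE_USE_DEEPSPEED` only while that object is constructed. Below is a minimal standalone sketch of the same set-construct-delete pattern; it is not part of the commit, it assumes `accelerate` (and, for a real distributed run, `deepspeed` plus a launcher) is available, and `ddp_timeout_seconds` is an illustrative stand-in for `TrainingArguments.ddp_timeout`.

import os
from datetime import timedelta

from accelerate import PartialState

# Illustrative stand-in for TrainingArguments.ddp_timeout (sketch only).
ddp_timeout_seconds = 1800

# Flag Accelerate to take its DeepSpeed initialization path while PartialState is built.
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
distributed_state = PartialState(timeout=timedelta(seconds=ddp_timeout_seconds))
# Drop the flag right away so it does not leak into later environment lookups
# (the "Del env var after partialstate" step in the commit message).
del os.environ["ACCELERATE_USE_DEEPSPEED"]

print(distributed_state.device, distributed_state.distributed_type)

Deleting the variable immediately after construction keeps the DeepSpeed flag scoped to this one initialization instead of affecting any later code that reads the environment.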
@@ -1615,7 +1602,6 @@ class TrainingArguments:
             # trigger an error that a device index is missing. Index 0 takes into account the
             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
             # will use the first GPU in that env, i.e. GPU#1
-            # device = self.distributed_state.device
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
@@ -1664,7 +1650,7 @@ class TrainingArguments:
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
         elif is_sagemaker_dp_enabled():
             return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif self.deepspeed or self.distributed_state.distributed_type != DistributedType.NO:
+        elif hasattr(self, "distributed_state") and self.distributed_state.distributed_type != DistributedType.NO:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
             return ParallelMode.NOT_DISTRIBUTED
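
The last hunk lets `parallel_mode` stop special-casing `self.deepspeed`: once the branch above stores a `PartialState` in `self.distributed_state`, its `distributed_type` already distinguishes a DeepSpeed (or any distributed) run from a plain one, and the `hasattr` guard covers calls made before device setup has run. A hypothetical, condensed restatement of that condition follows (names here are illustrative, not from the commit).

from accelerate.utils import DistributedType


def runs_distributed(args) -> bool:
    # Only consult distributed_state once device setup has created it;
    # otherwise treat the run as not distributed, mirroring the hasattr guard.
    return (
        hasattr(args, "distributed_state")
        and args.distributed_state.distributed_type != DistributedType.NO
    )

Under DeepSpeed, `distributed_state.distributed_type` is `DistributedType.DEEPSPEED`, so the `!= DistributedType.NO` comparison still reports the run as distributed without an explicit DeepSpeed check.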