Unverified commit b455ad0a authored by Zachary Mueller, committed by GitHub

Fix parallel mode check (#23409)

* Fix sagemaker/distributed state

* Fix correctly

* Bring back -1

* Bring back local rank for distributed check

* better version

* Cleanest option
parent db4d7652
@@ -1613,6 +1613,7 @@ class TrainingArguments:
             raise ImportError(
                 "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
             )
+        self.distributed_state = None
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
             self._n_gpu = 0
@@ -1636,7 +1637,7 @@ class TrainingArguments:
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
-            and self.distributed_state.distributed_type == DistributedType.NO
+            and self.parallel_mode != ParallelMode.DISTRIBUTED
         ):
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
@@ -1728,7 +1729,9 @@ class TrainingArguments:
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
         elif is_sagemaker_dp_enabled():
             return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif hasattr(self, "distributed_state") and self.distributed_state.distributed_type != DistributedType.NO:
+        elif (
+            self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO
+        ) or (self.distributed_state is None and self.local_rank != -1):
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
             return ParallelMode.NOT_DISTRIBUTED
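For context, here is a minimal, self-contained sketch of the check this PR introduces. The `DistributedType`/`ParallelMode` stand-ins and the free `parallel_mode()` function below are illustrative only, not the `transformers` implementation; the point is that when no `distributed_state` was created, the mode now falls back to `local_rank != -1` rather than falling through to a non-distributed mode as the old `hasattr` check could.

```python
# Illustration of the decision introduced by this PR (not the transformers code):
# resolve the parallel mode from either the accelerate-style state or, when no
# state was created, the raw local_rank.
from enum import Enum


class DistributedType(Enum):  # stand-in for accelerate's DistributedType
    NO = "NO"
    MULTI_GPU = "MULTI_GPU"


class ParallelMode(Enum):  # stand-in for transformers' ParallelMode
    NOT_PARALLEL = "not_parallel"
    NOT_DISTRIBUTED = "not_distributed"
    DISTRIBUTED = "distributed"


def parallel_mode(distributed_state, local_rank, n_gpu):
    # New check: a distributed_state that is not DistributedType.NO wins;
    # otherwise (state is None), fall back to local_rank != -1.
    if (distributed_state is not None and distributed_state.distributed_type != DistributedType.NO) or (
        distributed_state is None and local_rank != -1
    ):
        return ParallelMode.DISTRIBUTED
    elif n_gpu > 1:
        return ParallelMode.NOT_DISTRIBUTED
    return ParallelMode.NOT_PARALLEL


class _State:  # toy state whose distributed_type reports NO
    distributed_type = DistributedType.NO


print(parallel_mode(None, 0, 1))       # ParallelMode.DISTRIBUTED (local_rank fallback)
print(parallel_mode(_State(), -1, 2))  # ParallelMode.NOT_DISTRIBUTED (state says NO, 2 GPUs)
```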