"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7424b2848f57022fcbae90cc10079935d24e6e59"
Unverified Commit 84bac652 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Move import check to before state reset (#23906)

* Move import check to before state reset

* Guard better
parent e42869b0
...@@ -1667,12 +1667,12 @@ class TrainingArguments: ...@@ -1667,12 +1667,12 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device": def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices") logger.info("PyTorch: setting up devices")
AcceleratorState._reset_state() if not is_sagemaker_mp_enabled():
PartialState._reset_state() if not is_accelerate_available(check_partial_state=True):
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True): raise ImportError(
raise ImportError( "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`" )
) AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None self.distributed_state = None
if self.no_cuda: if self.no_cuda:
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment