Unverified Commit 5bb4ec62 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Raise err if minimum Accelerate version isn't available (#22841)



* Add warning about accelerate

* Version block Accelerate

* Include parse

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Check partial state

* Update param

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5f092194
......@@ -1531,6 +1531,10 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate`: Run `pip install --upgrade accelerate`"
)
if self.no_cuda:
self.distributed_state = PartialState(cpu=True)
device = self.distributed_state.device
......
......@@ -575,8 +575,15 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None
def is_accelerate_available():
return importlib.util.find_spec("accelerate") is not None
def is_accelerate_available(check_partial_state=False):
accelerate_available = importlib.util.find_spec("accelerate") is not None
if accelerate_available:
if check_partial_state:
return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0")
else:
return True
else:
return False
def is_optimum_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment