Unverified Commit 1c6ab9e9 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[utils] account for MPS when available in get_device(). (#11905)

* account for MPS when available in get_device().

* fix
parent 265840a0
......@@ -175,6 +175,8 @@ def get_device():
return "npu"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment