Unverified Commit f8325cfd authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[MPS] Make sure it doesn't break torch < 1.12 (#425)

* [MPS] Make sure it doesn't break torch < 1.12

* up
parent 8d9c4a53
......@@ -5,10 +5,15 @@ from distutils.util import strtobool
import torch
from packaging import version
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def parse_flag_from_env(key, default=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment