Unverified Commit bd9dfc23 authored by Nripesh Niketan's avatar Nripesh Niketan Committed by GitHub
Browse files

Add `is_torch_mps_available` function to utils (#24660)

* Add mps function utils

* black formating

* format fix

* Added MPS functionality to transformers

* format fix
parent ee339bad
......@@ -116,6 +116,7 @@ from .utils import (
is_torch_cuda_available,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
......
......@@ -35,6 +35,7 @@ from .utils import (
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mps_available,
is_torch_tpu_available,
requires_backends,
)
......@@ -411,6 +412,11 @@ class TrainerMemoryTracker:
if is_torch_cuda_available():
import torch
self.torch = torch
self.gpu = {}
elif is_torch_mps_available():
import torch
self.torch = torch
self.gpu = {}
else:
......
......@@ -161,6 +161,7 @@ from .import_utils import (
is_torch_cuda_available,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
......
......@@ -249,6 +249,15 @@ def is_torch_cuda_available():
return False
def is_torch_mps_available():
if is_torch_available():
import torch
if hasattr(torch.backends, "mps"):
return torch.backends.mps.is_available()
return False
def is_torch_bf16_gpu_available():
if not is_torch_available():
return False
......
......@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
uniform_init_parms = ["conv"]
ignore_init = ["lstm"]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif not any([x in name for x in ignore_init]):
elif not any(x in name for x in ignore_init):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment