Unverified Commit bd9dfc23 authored by Nripesh Niketan's avatar Nripesh Niketan Committed by GitHub
Browse files

Add `is_torch_mps_available` function to utils (#24660)

* Add mps function utils

* black formatting

* format fix

* Added MPS functionality to transformers

* format fix
parent ee339bad
...@@ -116,6 +116,7 @@ from .utils import ( ...@@ -116,6 +116,7 @@ from .utils import (
is_torch_cuda_available, is_torch_cuda_available,
is_torch_fx_available, is_torch_fx_available,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_tpu_available,
is_torchaudio_available, is_torchaudio_available,
......
...@@ -35,6 +35,7 @@ from .utils import ( ...@@ -35,6 +35,7 @@ from .utils import (
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
is_torch_cuda_available, is_torch_cuda_available,
is_torch_mps_available,
is_torch_tpu_available, is_torch_tpu_available,
requires_backends, requires_backends,
) )
...@@ -411,6 +412,11 @@ class TrainerMemoryTracker: ...@@ -411,6 +412,11 @@ class TrainerMemoryTracker:
if is_torch_cuda_available(): if is_torch_cuda_available():
import torch import torch
self.torch = torch
self.gpu = {}
elif is_torch_mps_available():
import torch
self.torch = torch self.torch = torch
self.gpu = {} self.gpu = {}
else: else:
......
...@@ -161,6 +161,7 @@ from .import_utils import ( ...@@ -161,6 +161,7 @@ from .import_utils import (
is_torch_cuda_available, is_torch_cuda_available,
is_torch_fx_available, is_torch_fx_available,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torch_mps_available,
is_torch_neuroncore_available, is_torch_neuroncore_available,
is_torch_tensorrt_fx_available, is_torch_tensorrt_fx_available,
is_torch_tf32_available, is_torch_tf32_available,
......
...@@ -249,6 +249,15 @@ def is_torch_cuda_available(): ...@@ -249,6 +249,15 @@ def is_torch_cuda_available():
return False return False
def is_torch_mps_available():
    """Return True when PyTorch is installed and the Apple MPS backend reports itself available.

    Guards against older torch builds that predate ``torch.backends.mps`` by
    checking for the attribute before querying it; returns ``False`` whenever
    torch is missing or the backend cannot be queried.
    """
    if not is_torch_available():
        return False
    import torch

    # Older torch releases do not ship the mps backend module at all.
    if not hasattr(torch.backends, "mps"):
        return False
    return torch.backends.mps.is_available()
def is_torch_bf16_gpu_available(): def is_torch_bf16_gpu_available():
if not is_torch_available(): if not is_torch_available():
return False return False
......
...@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
uniform_init_parms = ["conv"] uniform_init_parms = ["conv"]
ignore_init = ["lstm"] ignore_init = ["lstm"]
if param.requires_grad: if param.requires_grad:
if any([x in name for x in uniform_init_parms]): if any(x in name for x in uniform_init_parms):
self.assertTrue( self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
elif not any([x in name for x in ignore_init]): elif not any(x in name for x in ignore_init):
self.assertIn( self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(), ((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0], [0.0, 1.0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment