Unverified Commit 7c4c6f60 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Fix all is_torch_tpu_available issues (#17936)

* Fix all is_torch_tpu_available 
parent 77b76672
......@@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -30,7 +30,7 @@ from transformers.trainer_utils import PredictionOutput
logger = logging.getLogger(__name__)
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -24,7 +24,7 @@ from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
......
......@@ -467,7 +467,7 @@ def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
if is_torch_available():
......
......@@ -171,7 +171,7 @@ if version.parse(torch.__version__) >= version.parse("1.10"):
if is_datasets_available():
import datasets
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
......
......@@ -43,7 +43,7 @@ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
......
......@@ -307,7 +307,7 @@ def is_main_process(local_rank):
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
`local_rank`.
"""
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=True):
import torch_xla.core.xla_model as xm
return xm.get_ordinal() == 0
......@@ -318,7 +318,7 @@ def total_processes_number(local_rank):
"""
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
"""
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=True):
import torch_xla.core.xla_model as xm
return xm.xrt_world_size()
......
......@@ -52,7 +52,7 @@ if is_torch_available():
import torch
import torch.distributed as dist
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
......
......@@ -396,19 +396,22 @@ def is_ftfy_available():
return _ftfy_available
def is_torch_tpu_available():
def is_torch_tpu_available(check_device=True):
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
if not _torch_available:
return False
if importlib.util.find_spec("torch_xla") is None:
return False
import torch_xla.core.xla_model as xm
if importlib.util.find_spec("torch_xla") is not None:
if check_device:
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
try:
xm.xla_device()
import torch_xla.core.xla_model as xm
_ = xm.xla_device()
return True
except RuntimeError:
return False
return True
return False
def is_torchdynamo_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment