"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "884e3b1c53099c8d88b3897b903eb79f7cc37c51"
Unverified Commit 7c4c6f60 authored by Zachary Mueller, committed by GitHub

Fix all is_torch_tpu_available issues (#17936)

* Fix all is_torch_tpu_available 
parent 77b76672
@@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met

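This import-guard pattern repeats throughout the commit: at import time the modules only need to know whether torch_xla is installed, so they now pass check_device=False and skip the xm.xla_device() probe, which can fail on hosts without a TPU. A minimal sketch of the same guard in user code, mirroring the hunk above rather than adding anything beyond what the diff shows:

from transformers import is_torch_tpu_available

# check_device=False only verifies that torch_xla is importable; no XLA
# device is queried, so the guard is safe in modules imported on non-TPU hosts.
if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
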
@@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met

@@ -30,7 +30,7 @@ from transformers.trainer_utils import PredictionOutput
 logger = logging.getLogger(__name__)


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met

@@ -24,7 +24,7 @@ from .benchmark_args_utils import BenchmarkArguments
 if is_torch_available():
     import torch

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

@@ -467,7 +467,7 @@ def require_torch_tpu(test_case):
     """
     Decorator marking a test that requires a TPU (in PyTorch).
     """
-    return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
+    return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)


 if is_torch_available():

@@ -171,7 +171,7 @@ if version.parse(torch.__version__) >= version.parse("1.10"):
 if is_datasets_available():
     import datasets

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.parallel_loader as pl

@@ -43,7 +43,7 @@ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker
 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0

@@ -307,7 +307,7 @@ def is_main_process(local_rank):
     Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
-    if is_torch_tpu_available():
+    if is_torch_tpu_available(check_device=True):
         import torch_xla.core.xla_model as xm

         return xm.get_ordinal() == 0

@@ -318,7 +318,7 @@ def total_processes_number(local_rank):
     """
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
-    if is_torch_tpu_available():
+    if is_torch_tpu_available(check_device=True):
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()

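The two helpers above differ from the import guards: they call xm.get_ordinal() or xm.xrt_world_size() immediately afterwards, so they keep check_device=True to confirm an XLA device is actually reachable first. A rough sketch of that shape, using a hypothetical helper name purely for illustration:

from transformers import is_torch_tpu_available

def running_as_main_process(local_rank):
    # Hypothetical example, not part of this diff: the device check stays on
    # because xm.get_ordinal() is called right away and needs a live TPU.
    if is_torch_tpu_available(check_device=True):
        import torch_xla.core.xla_model as xm

        return xm.get_ordinal() == 0
    # Assumed non-TPU fallback for this sketch: rank -1 or 0 counts as main.
    return local_rank in (-1, 0)
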
@@ -52,7 +52,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

@@ -396,19 +396,22 @@ def is_ftfy_available():
     return _ftfy_available


-def is_torch_tpu_available():
+def is_torch_tpu_available(check_device=True):
+    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
     if not _torch_available:
         return False
-    if importlib.util.find_spec("torch_xla") is None:
-        return False
-    import torch_xla.core.xla_model as xm
-
-    # We need to check if `xla_device` can be found, will raise a RuntimeError if not
-    try:
-        xm.xla_device()
-        return True
-    except RuntimeError:
-        return False
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+        return True
+    return False


 def is_torchdynamo_available():

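With the new signature, callers get two modes from the same helper. A short usage sketch, assuming a standard transformers install:

from transformers import is_torch_tpu_available

# Cheap check: True as soon as torch and torch_xla are importable; safe to
# call at import time on machines with no TPU attached.
xla_installed = is_torch_tpu_available(check_device=False)

# Full check (the default): additionally tries xm.xla_device() and returns
# False if no XLA device can be acquired.
tpu_reachable = is_torch_tpu_available()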