chenpangpang / transformers · Commits · 7c4c6f60

Unverified commit 7c4c6f60, authored Jun 29, 2022 by Zachary Mueller, committed by GitHub on Jun 29, 2022.
Fix all is_torch_tpu_available issues (#17936)
* Fix all is_torch_tpu_available
Parent: 77b76672
Showing 10 changed files with 22 additions and 19 deletions.
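The fix threads a check_device flag through is_torch_tpu_available: module-level import guards now call is_torch_tpu_available(check_device=False), which only verifies that torch_xla is installed, while the helpers in trainer_utils.py that actually need to talk to a device pass check_device=True explicitly. A minimal sketch of the import-guard pattern applied across the example scripts and trainer modules (not copied verbatim from any single file in the diff):

from transformers import is_torch_tpu_available

# check_device=False keeps this guard cheap: it only asks whether torch_xla is
# importable, so the module can be imported on machines with no TPU attached.
if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm  # same XLA helper module the patched files import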
examples/pytorch/question-answering/trainer_qa.py (+1 -1)
examples/pytorch/question-answering/trainer_seq2seq_qa.py (+1 -1)
examples/research_projects/quantization-qdqbert/trainer_quant_qa.py (+1 -1)
src/transformers/benchmark/benchmark_args.py (+1 -1)
src/transformers/testing_utils.py (+1 -1)
src/transformers/trainer.py (+1 -1)
src/transformers/trainer_pt_utils.py (+1 -1)
src/transformers/trainer_utils.py (+2 -2)
src/transformers/training_args.py (+1 -1)
src/transformers/utils/import_utils.py (+12 -9)
examples/pytorch/question-answering/trainer_qa.py

@@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
examples/pytorch/question-answering/trainer_seq2seq_qa.py

@@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
examples/research_projects/quantization-qdqbert/trainer_quant_qa.py

@@ -30,7 +30,7 @@ from transformers.trainer_utils import PredictionOutput
 logger = logging.getLogger(__name__)

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
src/transformers/benchmark/benchmark_args.py

@@ -24,7 +24,7 @@ from .benchmark_args_utils import BenchmarkArguments
 if is_torch_available():
     import torch

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
src/transformers/testing_utils.py

@@ -467,7 +467,7 @@ def require_torch_tpu(test_case):
     """
     Decorator marking a test that requires a TPU (in PyTorch).
     """
-    return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
+    return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)


 if is_torch_available():
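For test authors the decorator is used the same way as before; only the availability check behind it got cheaper. A usage sketch, assuming a current transformers install, with a hypothetical test class for illustration:

import unittest

from transformers.testing_utils import require_torch_tpu


@require_torch_tpu
class TpuSmokeTest(unittest.TestCase):  # hypothetical test case, not part of the repository
    def test_runs_on_tpu(self):
        # unittest skips this unless the TPU requirement above is satisfied
        self.assertTrue(True)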
src/transformers/trainer.py

@@ -171,7 +171,7 @@ if version.parse(torch.__version__) >= version.parse("1.10"):
 if is_datasets_available():
     import datasets

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.parallel_loader as pl
src/transformers/trainer_pt_utils.py

@@ -43,7 +43,7 @@ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_
 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
src/transformers/trainer_utils.py

@@ -307,7 +307,7 @@ def is_main_process(local_rank):
     Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
-    if is_torch_tpu_available():
+    if is_torch_tpu_available(check_device=True):
         import torch_xla.core.xla_model as xm

         return xm.get_ordinal() == 0

@@ -318,7 +318,7 @@ def total_processes_number(local_rank):
     """
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
-    if is_torch_tpu_available():
+    if is_torch_tpu_available(check_device=True):
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
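These two helpers keep check_device=True because xm.get_ordinal() and xm.xrt_world_size() are only meaningful once an XLA device can actually be acquired. A small sketch of how they are typically queried; the local_rank value is a placeholder that would normally come from the distributed launcher or TrainingArguments:

from transformers.trainer_utils import is_main_process, total_processes_number

local_rank = -1  # placeholder; supplied by the launcher in a real run
if is_main_process(local_rank):
    # On TPU this resolves via xm.get_ordinal(); otherwise it falls back to local_rank.
    print(f"main process out of {total_processes_number(local_rank)} process(es)")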
src/transformers/training_args.py

@@ -52,7 +52,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
src/transformers/utils/import_utils.py

@@ -396,19 +396,22 @@ def is_ftfy_available():
     return _ftfy_available


-def is_torch_tpu_available():
+def is_torch_tpu_available(check_device=True):
+    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
     if not _torch_available:
         return False
-    if importlib.util.find_spec("torch_xla") is None:
-        return False
-    import torch_xla.core.xla_model as xm
-
-    # We need to check if `xla_device` can be found, will raise a RuntimeError if not
-    try:
-        xm.xla_device()
-        return True
-    except RuntimeError:
-        return False
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+        return True
+    return False


 def is_torchdynamo_available():
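The net behaviour of the rewritten helper, sketched below under the assumption that this version of transformers is installed: check_device=False only answers "is torch_xla importable?", while the default check_device=True additionally tries xm.xla_device() and returns False when no XLA device can be acquired.

from transformers import is_torch_tpu_available

# Cheap check: True whenever torch and torch_xla are installed, even with no TPU attached.
print(is_torch_tpu_available(check_device=False))

# Full check (the default): also calls xm.xla_device(), so this returns False on a
# torch_xla installation that cannot reach an XLA/TPU device.
print(is_torch_tpu_available())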