chenpangpang / transformers · Commits

Commit 7b23a582 (unverified)
Authored Dec 14, 2022 by amyeroberts, committed by GitHub on Dec 14, 2022
Parent: 7c9e2f24

Replaces xxx_required with requires_backends (#20715)

* Replaces xxx_required with requires_backends
* Fixup
Changes: 9 files changed, 30 additions(+), 58 deletions(-)

src/transformers/benchmark/benchmark_args.py      +5  -5
src/transformers/benchmark/benchmark_args_tf.py   +7  -7
src/transformers/feature_extraction_utils.py      +2  -2
src/transformers/file_utils.py                    +0  -2
src/transformers/tokenization_utils_base.py       +2  -2
src/transformers/training_args.py                 +8  -8
src/transformers/training_args_tf.py              +5  -5
src/transformers/utils/__init__.py                +0  -2
src/transformers/utils/import_utils.py            +1  -25
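Every file in this commit follows the same pattern: the backend-specific torch_required / tf_required decorators are removed and an explicit requires_backends(self, ["torch"]) (or ["tf"]) call is added at the top of the method body, so the availability check goes through the same helper used elsewhere in the library. A minimal before/after sketch of the pattern (the device property and _setup_devices attribute here are only illustrative):

from transformers.utils import requires_backends

class Example:
    # Before this commit, availability was enforced with a decorator:
    #
    #     @property
    #     @torch_required
    #     def device(self) -> "torch.device":
    #         return self._setup_devices[0]
    #
    # After, the check is an explicit call at the top of the body:
    @property
    def device(self) -> "torch.device":
        requires_backends(self, ["torch"])  # fails with the missing-backend error if PyTorch is not installed
        return self._setup_devices[0]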
src/transformers/benchmark/benchmark_args.py

@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple

-from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
+from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments

@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
     )

     @cached_property
-    @torch_required
     def _setup_devices(self) -> Tuple["torch.device", int]:
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if not self.cuda:
             device = torch.device("cpu")

@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
         return is_torch_tpu_available() and self.tpu

     @property
-    @torch_required
     def device_idx(self) -> int:
+        requires_backends(self, ["torch"])
         # TODO(PVP): currently only single GPU is supported
         return torch.cuda.current_device()

     @property
-    @torch_required
     def device(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         return self._setup_devices[0]

     @property
-    @torch_required
     def n_gpu(self):
+        requires_backends(self, ["torch"])
         return self._setup_devices[1]

     @property
src/transformers/benchmark/benchmark_args_tf.py

@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple

-from ..utils import cached_property, is_tf_available, logging, tf_required
+from ..utils import cached_property, is_tf_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments

@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
     )

     @cached_property
-    @tf_required
     def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         tpu = None
         if self.tpu:
             try:

@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
         return tpu

     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         if self.is_tpu:
             tf.config.experimental_connect_to_cluster(self._setup_tpu)
             tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)

@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
         return strategy

     @property
-    @tf_required
     def is_tpu(self) -> bool:
+        requires_backends(self, ["tf"])
         return self._setup_tpu is not None

     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
+        requires_backends(self, ["tf"])
         return self._setup_strategy

     @property
-    @tf_required
     def gpu_list(self):
+        requires_backends(self, ["tf"])
         return tf.config.list_physical_devices("GPU")

     @property
-    @tf_required
     def n_gpu(self) -> int:
+        requires_backends(self, ["tf"])
         if self.cuda:
             return len(self.gpu_list)
         return 0
src/transformers/feature_extraction_utils.py

@@ -42,7 +42,7 @@ from .utils import (
     is_torch_device,
     is_torch_dtype,
     logging,
-    torch_required,
+    requires_backends,
 )

@@ -175,7 +175,6 @@ class BatchFeature(UserDict):
         return self

-    @torch_required
     def to(self, *args, **kwargs) -> "BatchFeature":
         """
         Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in

@@ -190,6 +189,7 @@ class BatchFeature(UserDict):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
         import torch  # noqa

         new_data = {}
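A small usage sketch of BatchFeature.to after the change; the checkpoint name is only an example, and any vision checkpoint with a feature extractor works the same way:

import torch
from PIL import Image
from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
features = extractor(images=Image.new("RGB", (224, 224)), return_tensors="pt")

if torch.cuda.is_available():
    features = features.to("cuda")    # move every tensor in the batch to the GPU
features = features.to(torch.float32) # the same call accepts dtype arguments, per the docstring above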
src/transformers/file_utils.py

@@ -127,10 +127,8 @@ from .utils import (
     is_vision_available,
     replace_return_docstrings,
     requires_backends,
-    tf_required,
     to_numpy,
     to_py_obj,
     torch_only_method,
-    torch_required,
     torch_version,
 )
src/transformers/tokenization_utils_base.py

@@ -56,8 +56,8 @@ from .utils import (
     is_torch_device,
     is_torch_tensor,
     logging,
+    requires_backends,
     to_py_obj,
-    torch_required,
 )

@@ -739,7 +739,6 @@ class BatchEncoding(UserDict):
         return self

-    @torch_required
     def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         """
         Send all values to device by calling `v.to(device)` (PyTorch only).

@@ -750,6 +749,7 @@ class BatchEncoding(UserDict):
         Returns:
             [`BatchEncoding`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
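A corresponding usage sketch for BatchEncoding.to (the checkpoint name is only an example):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer("Hello world", return_tensors="pt")

device = "cuda" if torch.cuda.is_available() else "cpu"
encoding = encoding.to(device)  # moves input_ids, attention_mask, etc. in one call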
src/transformers/training_args.py

@@ -50,7 +50,6 @@ from .utils import (
     is_torch_tpu_available,
     logging,
     requires_backends,
-    torch_required,
 )

@@ -1386,8 +1385,8 @@ class TrainingArguments:
         return timedelta(seconds=self.ddp_timeout)

     @cached_property
-    @torch_required
     def _setup_devices(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
             logger.warning(

@@ -1537,15 +1536,14 @@
         return device

     @property
-    @torch_required
     def device(self) -> "torch.device":
         """
         The device used by this process.
         """
+        requires_backends(self, ["torch"])
         return self._setup_devices

     @property
-    @torch_required
     def n_gpu(self):
         """
         The number of GPUs used by this process.

@@ -1554,12 +1552,12 @@
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
+        requires_backends(self, ["torch"])
         # Make sure `self._n_gpu` is properly setup.
         _ = self._setup_devices
         return self._n_gpu

     @property
-    @torch_required
     def parallel_mode(self):
         """
         The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:

@@ -1570,6 +1568,7 @@
               `torch.nn.DistributedDataParallel`).
         - `ParallelMode.TPU`: several TPU cores.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return ParallelMode.TPU
         elif is_sagemaker_mp_enabled():

@@ -1584,11 +1583,12 @@
         return ParallelMode.NOT_PARALLEL

     @property
-    @torch_required
     def world_size(self):
         """
         The number of processes used in parallel.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.xrt_world_size()
         elif is_sagemaker_mp_enabled():

@@ -1600,11 +1600,11 @@
         return 1

     @property
-    @torch_required
     def process_index(self):
         """
         The index of the current process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_ordinal()
         elif is_sagemaker_mp_enabled():

@@ -1616,11 +1616,11 @@
         return 0

     @property
-    @torch_required
     def local_process_index(self):
         """
         The index of the local process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_local_ordinal()
         elif is_sagemaker_mp_enabled():
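The affected TrainingArguments properties are public API; a quick sketch of how they are read (the output directory name is arbitrary):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_output")  # output_dir is the only required argument

# Each property below now calls requires_backends(self, ["torch"]) internally, so reading
# them without PyTorch installed fails with the standard missing-backend ImportError.
print(args.device)         # torch.device chosen for this process
print(args.n_gpu)          # number of GPUs used by this process
print(args.world_size)     # number of processes used in parallel
print(args.parallel_mode)  # e.g. ParallelMode.NOT_PARALLEL or ParallelMode.DISTRIBUTED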
src/transformers/training_args_tf.py

@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
 from typing import Optional, Tuple

 from .training_args import TrainingArguments
-from .utils import cached_property, is_tf_available, logging, tf_required
+from .utils import cached_property, is_tf_available, logging, requires_backends

 logger = logging.get_logger(__name__)

@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
     xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})

     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
+        requires_backends(self, ["tf"])
         logger.info("Tensorflow: setting up strategy")

         gpus = tf.config.list_physical_devices("GPU")

@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments):
         return strategy

     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
         """
         The strategy used for distributed training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy

     @property
-    @tf_required
     def n_replicas(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy.num_replicas_in_sync

     @property

@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments):
         return per_device_batch_size * self.n_replicas

     @property
-    @tf_required
     def n_gpu(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         warnings.warn(
             "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
             FutureWarning,
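Similarly for TFTrainingArguments, whose strategy and n_replicas properties are typically consumed like this (a sketch, assuming TensorFlow is installed):

from transformers import TFTrainingArguments

args = TFTrainingArguments(output_dir="tmp_output")

print(args.n_replicas)       # number of replicas in the tf.distribute strategy
with args.strategy.scope():  # build/compile the model under the selected strategy
    pass                     # e.g. model = TFAutoModel.from_pretrained(...)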
src/transformers/utils/__init__.py

@@ -163,9 +163,7 @@ from .import_utils import (
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,
-    tf_required,
     torch_only_method,
-    torch_required,
     torch_version,
 )
src/transformers/utils/import_utils.py

@@ -22,7 +22,7 @@ import shutil
 import sys
 import warnings
 from collections import OrderedDict
-from functools import lru_cache, wraps
+from functools import lru_cache
 from itertools import chain
 from types import ModuleType
 from typing import Any

@@ -1039,30 +1039,6 @@ class DummyObject(type):
         requires_backends(cls, cls._backends)


-def torch_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_torch_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
-
-    return wrapper
-
-
-def tf_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_tf_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires TF.")
-
-    return wrapper
-
-
 def is_torch_fx_proxy(x):
     if is_torch_fx_available():
         import torch.fx
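For context, requires_backends (defined earlier in this same module) performs essentially the check the removed decorators did, but for an arbitrary object and list of backends. A simplified sketch of the idea, not the actual implementation, which covers many more backends and richer error messages:

from transformers.utils import is_tf_available, is_torch_available

def requires_backends_sketch(obj, backends):
    # Name the offending class or function in the error, as the real helper does.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = {"torch": is_torch_available, "tf": is_tf_available}
    missing = [b for b in backends if not checks[b]()]
    if missing:
        raise ImportError(f"`{name}` requires the following backends: {', '.join(missing)}.")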