"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "da9754a3a0d14f85668d3558ef3aaef6dc8e8dfe"
Unverified Commit 7b23a582 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Replaces xxx_required with requires_backends (#20715)

* Replaces xxx_required with requires_backends

* Fixup
parent 7c9e2f24
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments): ...@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
) )
@cached_property @cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]: def _setup_devices(self) -> Tuple["torch.device", int]:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices") logger.info("PyTorch: setting up devices")
if not self.cuda: if not self.cuda:
device = torch.device("cpu") device = torch.device("cpu")
...@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments): ...@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
return is_torch_tpu_available() and self.tpu return is_torch_tpu_available() and self.tpu
@property @property
@torch_required
def device_idx(self) -> int: def device_idx(self) -> int:
requires_backends(self, ["torch"])
# TODO(PVP): currently only single GPU is supported # TODO(PVP): currently only single GPU is supported
return torch.cuda.current_device() return torch.cuda.current_device()
@property @property
@torch_required
def device(self) -> "torch.device": def device(self) -> "torch.device":
requires_backends(self, ["torch"])
return self._setup_devices[0] return self._setup_devices[0]
@property @property
@torch_required
def n_gpu(self): def n_gpu(self):
requires_backends(self, ["torch"])
return self._setup_devices[1] return self._setup_devices[1]
@property @property
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..utils import cached_property, is_tf_available, logging, tf_required from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments): ...@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
) )
@cached_property @cached_property
@tf_required
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]: def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
tpu = None tpu = None
if self.tpu: if self.tpu:
try: try:
...@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments): ...@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
return tpu return tpu
@cached_property @cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]: def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
if self.is_tpu: if self.is_tpu:
tf.config.experimental_connect_to_cluster(self._setup_tpu) tf.config.experimental_connect_to_cluster(self._setup_tpu)
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu) tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
...@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments): ...@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
return strategy return strategy
@property @property
@tf_required
def is_tpu(self) -> bool: def is_tpu(self) -> bool:
requires_backends(self, ["tf"])
return self._setup_tpu is not None return self._setup_tpu is not None
@property @property
@tf_required
def strategy(self) -> "tf.distribute.Strategy": def strategy(self) -> "tf.distribute.Strategy":
requires_backends(self, ["tf"])
return self._setup_strategy return self._setup_strategy
@property @property
@tf_required
def gpu_list(self): def gpu_list(self):
requires_backends(self, ["tf"])
return tf.config.list_physical_devices("GPU") return tf.config.list_physical_devices("GPU")
@property @property
@tf_required
def n_gpu(self) -> int: def n_gpu(self) -> int:
requires_backends(self, ["tf"])
if self.cuda: if self.cuda:
return len(self.gpu_list) return len(self.gpu_list)
return 0 return 0
......
...@@ -42,7 +42,7 @@ from .utils import ( ...@@ -42,7 +42,7 @@ from .utils import (
is_torch_device, is_torch_device,
is_torch_dtype, is_torch_dtype,
logging, logging,
torch_required, requires_backends,
) )
...@@ -175,7 +175,6 @@ class BatchFeature(UserDict): ...@@ -175,7 +175,6 @@ class BatchFeature(UserDict):
return self return self
@torch_required
def to(self, *args, **kwargs) -> "BatchFeature": def to(self, *args, **kwargs) -> "BatchFeature":
""" """
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
...@@ -190,6 +189,7 @@ class BatchFeature(UserDict): ...@@ -190,6 +189,7 @@ class BatchFeature(UserDict):
Returns: Returns:
[`BatchFeature`]: The same instance after modification. [`BatchFeature`]: The same instance after modification.
""" """
requires_backends(self, ["torch"])
import torch # noqa import torch # noqa
new_data = {} new_data = {}
......
...@@ -127,10 +127,8 @@ from .utils import ( ...@@ -127,10 +127,8 @@ from .utils import (
is_vision_available, is_vision_available,
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
tf_required,
to_numpy, to_numpy,
to_py_obj, to_py_obj,
torch_only_method, torch_only_method,
torch_required,
torch_version, torch_version,
) )
...@@ -56,8 +56,8 @@ from .utils import ( ...@@ -56,8 +56,8 @@ from .utils import (
is_torch_device, is_torch_device,
is_torch_tensor, is_torch_tensor,
logging, logging,
requires_backends,
to_py_obj, to_py_obj,
torch_required,
) )
...@@ -739,7 +739,6 @@ class BatchEncoding(UserDict): ...@@ -739,7 +739,6 @@ class BatchEncoding(UserDict):
return self return self
@torch_required
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
""" """
Send all values to device by calling `v.to(device)` (PyTorch only). Send all values to device by calling `v.to(device)` (PyTorch only).
...@@ -750,6 +749,7 @@ class BatchEncoding(UserDict): ...@@ -750,6 +749,7 @@ class BatchEncoding(UserDict):
Returns: Returns:
[`BatchEncoding`]: The same instance after modification. [`BatchEncoding`]: The same instance after modification.
""" """
requires_backends(self, ["torch"])
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
......
...@@ -50,7 +50,6 @@ from .utils import ( ...@@ -50,7 +50,6 @@ from .utils import (
is_torch_tpu_available, is_torch_tpu_available,
logging, logging,
requires_backends, requires_backends,
torch_required,
) )
...@@ -1386,8 +1385,8 @@ class TrainingArguments: ...@@ -1386,8 +1385,8 @@ class TrainingArguments:
return timedelta(seconds=self.ddp_timeout) return timedelta(seconds=self.ddp_timeout)
@cached_property @cached_property
@torch_required
def _setup_devices(self) -> "torch.device": def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices") logger.info("PyTorch: setting up devices")
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1: if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
logger.warning( logger.warning(
...@@ -1537,15 +1536,14 @@ class TrainingArguments: ...@@ -1537,15 +1536,14 @@ class TrainingArguments:
return device return device
@property @property
@torch_required
def device(self) -> "torch.device": def device(self) -> "torch.device":
""" """
The device used by this process. The device used by this process.
""" """
requires_backends(self, ["torch"])
return self._setup_devices return self._setup_devices
@property @property
@torch_required
def n_gpu(self): def n_gpu(self):
""" """
The number of GPUs used by this process. The number of GPUs used by this process.
...@@ -1554,12 +1552,12 @@ class TrainingArguments: ...@@ -1554,12 +1552,12 @@ class TrainingArguments:
This will only be greater than one when you have multiple GPUs available but are not using distributed This will only be greater than one when you have multiple GPUs available but are not using distributed
training. For distributed training, it will always be 1. training. For distributed training, it will always be 1.
""" """
requires_backends(self, ["torch"])
# Make sure `self._n_gpu` is properly setup. # Make sure `self._n_gpu` is properly setup.
_ = self._setup_devices _ = self._setup_devices
return self._n_gpu return self._n_gpu
@property @property
@torch_required
def parallel_mode(self): def parallel_mode(self):
""" """
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of: The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
...@@ -1570,6 +1568,7 @@ class TrainingArguments: ...@@ -1570,6 +1568,7 @@ class TrainingArguments:
`torch.nn.DistributedDataParallel`). `torch.nn.DistributedDataParallel`).
- `ParallelMode.TPU`: several TPU cores. - `ParallelMode.TPU`: several TPU cores.
""" """
requires_backends(self, ["torch"])
if is_torch_tpu_available(): if is_torch_tpu_available():
return ParallelMode.TPU return ParallelMode.TPU
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
...@@ -1584,11 +1583,12 @@ class TrainingArguments: ...@@ -1584,11 +1583,12 @@ class TrainingArguments:
return ParallelMode.NOT_PARALLEL return ParallelMode.NOT_PARALLEL
@property @property
@torch_required
def world_size(self): def world_size(self):
""" """
The number of processes used in parallel. The number of processes used in parallel.
""" """
requires_backends(self, ["torch"])
if is_torch_tpu_available(): if is_torch_tpu_available():
return xm.xrt_world_size() return xm.xrt_world_size()
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
...@@ -1600,11 +1600,11 @@ class TrainingArguments: ...@@ -1600,11 +1600,11 @@ class TrainingArguments:
return 1 return 1
@property @property
@torch_required
def process_index(self): def process_index(self):
""" """
The index of the current process used. The index of the current process used.
""" """
requires_backends(self, ["torch"])
if is_torch_tpu_available(): if is_torch_tpu_available():
return xm.get_ordinal() return xm.get_ordinal()
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
...@@ -1616,11 +1616,11 @@ class TrainingArguments: ...@@ -1616,11 +1616,11 @@ class TrainingArguments:
return 0 return 0
@property @property
@torch_required
def local_process_index(self): def local_process_index(self):
""" """
The index of the local process used. The index of the local process used.
""" """
requires_backends(self, ["torch"])
if is_torch_tpu_available(): if is_torch_tpu_available():
return xm.get_local_ordinal() return xm.get_local_ordinal()
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
......
...@@ -17,7 +17,7 @@ from dataclasses import dataclass, field ...@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from typing import Optional, Tuple from typing import Optional, Tuple
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required from .utils import cached_property, is_tf_available, logging, requires_backends
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments): ...@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"}) xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
@cached_property @cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]: def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
requires_backends(self, ["tf"])
logger.info("Tensorflow: setting up strategy") logger.info("Tensorflow: setting up strategy")
gpus = tf.config.list_physical_devices("GPU") gpus = tf.config.list_physical_devices("GPU")
...@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments): ...@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments):
return strategy return strategy
@property @property
@tf_required
def strategy(self) -> "tf.distribute.Strategy": def strategy(self) -> "tf.distribute.Strategy":
""" """
The strategy used for distributed training. The strategy used for distributed training.
""" """
requires_backends(self, ["tf"])
return self._setup_strategy return self._setup_strategy
@property @property
@tf_required
def n_replicas(self) -> int: def n_replicas(self) -> int:
""" """
The number of replicas (CPUs, GPUs or TPU cores) used in this training. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
""" """
requires_backends(self, ["tf"])
return self._setup_strategy.num_replicas_in_sync return self._setup_strategy.num_replicas_in_sync
@property @property
...@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments): ...@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments):
return per_device_batch_size * self.n_replicas return per_device_batch_size * self.n_replicas
@property @property
@tf_required
def n_gpu(self) -> int: def n_gpu(self) -> int:
""" """
The number of replicas (CPUs, GPUs or TPU cores) used in this training. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
""" """
requires_backends(self, ["tf"])
warnings.warn( warnings.warn(
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.", "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
FutureWarning, FutureWarning,
......
...@@ -163,9 +163,7 @@ from .import_utils import ( ...@@ -163,9 +163,7 @@ from .import_utils import (
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
is_vision_available, is_vision_available,
requires_backends, requires_backends,
tf_required,
torch_only_method, torch_only_method,
torch_required,
torch_version, torch_version,
) )
......
...@@ -22,7 +22,7 @@ import shutil ...@@ -22,7 +22,7 @@ import shutil
import sys import sys
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from functools import lru_cache, wraps from functools import lru_cache
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any
...@@ -1039,30 +1039,6 @@ class DummyObject(type): ...@@ -1039,30 +1039,6 @@ class DummyObject(type):
requires_backends(cls, cls._backends) requires_backends(cls, cls._backends)
def torch_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
return wrapper
def tf_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
def is_torch_fx_proxy(x): def is_torch_fx_proxy(x):
if is_torch_fx_available(): if is_torch_fx_available():
import torch.fx import torch.fx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment