Unverified commit 7b23a582, authored by amyeroberts, committed by GitHub

Replaces xxx_required with requires_backends (#20715)

* Replaces xxx_required with requires_backends

* Fixup
parent 7c9e2f24
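
The pattern is the same in every file touched below: the `@torch_required` / `@tf_required` decorator is dropped and the method instead calls `requires_backends` as its first statement, which raises an `ImportError` when the named backend is missing. A minimal sketch of the before/after shape, assuming the `requires_backends` helper exported from `transformers.utils`; the class name and attribute below are illustrative, not taken from a specific file in the diff:

from transformers.utils import requires_backends


class ExampleArguments:  # illustrative stand-in for the argument dataclasses in this diff
    # Before this PR, a dedicated decorator guarded each torch-only member:
    #
    #     @property
    #     @torch_required
    #     def device(self) -> "torch.device":
    #         return self._setup_devices
    #
    # After this PR, the member asserts the backend itself on entry:
    @property
    def device(self) -> "torch.device":
        requires_backends(self, ["torch"])  # raises ImportError if PyTorch is not installed
        return self._setup_devices
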
@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 
-from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
+from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments
@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
     )
 
     @cached_property
-    @torch_required
     def _setup_devices(self) -> Tuple["torch.device", int]:
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if not self.cuda:
             device = torch.device("cpu")
@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
         return is_torch_tpu_available() and self.tpu
 
     @property
-    @torch_required
     def device_idx(self) -> int:
+        requires_backends(self, ["torch"])
         # TODO(PVP): currently only single GPU is supported
         return torch.cuda.current_device()
 
     @property
-    @torch_required
     def device(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         return self._setup_devices[0]
 
     @property
-    @torch_required
     def n_gpu(self):
+        requires_backends(self, ["torch"])
         return self._setup_devices[1]
 
     @property
......
@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 
-from ..utils import cached_property, is_tf_available, logging, tf_required
+from ..utils import cached_property, is_tf_available, logging, requires_backends
 from .benchmark_args_utils import BenchmarkArguments
@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
     )
 
     @cached_property
-    @tf_required
     def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         tpu = None
         if self.tpu:
             try:
@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
         return tpu
 
     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
+        requires_backends(self, ["tf"])
         if self.is_tpu:
             tf.config.experimental_connect_to_cluster(self._setup_tpu)
             tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
         return strategy
 
     @property
-    @tf_required
     def is_tpu(self) -> bool:
+        requires_backends(self, ["tf"])
         return self._setup_tpu is not None
 
     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
+        requires_backends(self, ["tf"])
         return self._setup_strategy
 
     @property
-    @tf_required
     def gpu_list(self):
+        requires_backends(self, ["tf"])
         return tf.config.list_physical_devices("GPU")
 
     @property
-    @tf_required
     def n_gpu(self) -> int:
+        requires_backends(self, ["tf"])
         if self.cuda:
             return len(self.gpu_list)
         return 0
......
@@ -42,7 +42,7 @@ from .utils import (
     is_torch_device,
     is_torch_dtype,
     logging,
-    torch_required,
+    requires_backends,
 )
@@ -175,7 +175,6 @@ class BatchFeature(UserDict):
         return self
 
-    @torch_required
     def to(self, *args, **kwargs) -> "BatchFeature":
         """
         Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
@@ -190,6 +189,7 @@ class BatchFeature(UserDict):
 
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
         import torch  # noqa
 
         new_data = {}
......
@@ -127,10 +127,8 @@ from .utils import (
     is_vision_available,
     replace_return_docstrings,
     requires_backends,
-    tf_required,
     to_numpy,
     to_py_obj,
     torch_only_method,
-    torch_required,
     torch_version,
 )
@@ -56,8 +56,8 @@ from .utils import (
     is_torch_device,
     is_torch_tensor,
     logging,
+    requires_backends,
     to_py_obj,
-    torch_required,
 )
@@ -739,7 +739,6 @@ class BatchEncoding(UserDict):
         return self
 
-    @torch_required
     def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         """
         Send all values to device by calling `v.to(device)` (PyTorch only).
@@ -750,6 +749,7 @@ class BatchEncoding(UserDict):
 
         Returns:
             [`BatchEncoding`]: The same instance after modification.
         """
+        requires_backends(self, ["torch"])
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
......
@@ -50,7 +50,6 @@ from .utils import (
     is_torch_tpu_available,
     logging,
     requires_backends,
-    torch_required,
 )
@@ -1386,8 +1385,8 @@ class TrainingArguments:
         return timedelta(seconds=self.ddp_timeout)
 
     @cached_property
-    @torch_required
     def _setup_devices(self) -> "torch.device":
+        requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
         if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
             logger.warning(
@@ -1537,15 +1536,14 @@ class TrainingArguments:
         return device
 
     @property
-    @torch_required
     def device(self) -> "torch.device":
         """
         The device used by this process.
         """
+        requires_backends(self, ["torch"])
         return self._setup_devices
 
     @property
-    @torch_required
     def n_gpu(self):
         """
         The number of GPUs used by this process.
@@ -1554,12 +1552,12 @@ class TrainingArguments:
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
+        requires_backends(self, ["torch"])
         # Make sure `self._n_gpu` is properly setup.
         _ = self._setup_devices
         return self._n_gpu
 
     @property
-    @torch_required
     def parallel_mode(self):
         """
         The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
@@ -1570,6 +1568,7 @@
           `torch.nn.DistributedDataParallel`).
         - `ParallelMode.TPU`: several TPU cores.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return ParallelMode.TPU
         elif is_sagemaker_mp_enabled():
@@ -1584,11 +1583,12 @@
         return ParallelMode.NOT_PARALLEL
 
     @property
-    @torch_required
     def world_size(self):
         """
         The number of processes used in parallel.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.xrt_world_size()
         elif is_sagemaker_mp_enabled():
@@ -1600,11 +1600,11 @@
         return 1
 
     @property
-    @torch_required
     def process_index(self):
         """
         The index of the current process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_ordinal()
         elif is_sagemaker_mp_enabled():
@@ -1616,11 +1616,11 @@
         return 0
 
     @property
-    @torch_required
     def local_process_index(self):
         """
         The index of the local process used.
         """
+        requires_backends(self, ["torch"])
         if is_torch_tpu_available():
             return xm.get_local_ordinal()
         elif is_sagemaker_mp_enabled():
......
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
 from typing import Optional, Tuple
 
 from .training_args import TrainingArguments
-from .utils import cached_property, is_tf_available, logging, tf_required
+from .utils import cached_property, is_tf_available, logging, requires_backends
 
 logger = logging.get_logger(__name__)
@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
     xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
 
     @cached_property
-    @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
+        requires_backends(self, ["tf"])
         logger.info("Tensorflow: setting up strategy")
 
         gpus = tf.config.list_physical_devices("GPU")
@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments):
         return strategy
 
     @property
-    @tf_required
     def strategy(self) -> "tf.distribute.Strategy":
         """
         The strategy used for distributed training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy
 
     @property
-    @tf_required
     def n_replicas(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         return self._setup_strategy.num_replicas_in_sync
 
     @property
@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments):
         return per_device_batch_size * self.n_replicas
 
     @property
-    @tf_required
     def n_gpu(self) -> int:
         """
         The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
+        requires_backends(self, ["tf"])
         warnings.warn(
             "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
             FutureWarning,
......
@@ -163,9 +163,7 @@ from .import_utils import (
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,
-    tf_required,
     torch_only_method,
-    torch_required,
     torch_version,
 )
......
@@ -22,7 +22,7 @@ import shutil
 import sys
 import warnings
 from collections import OrderedDict
-from functools import lru_cache, wraps
+from functools import lru_cache
 from itertools import chain
 from types import ModuleType
 from typing import Any
@@ -1039,30 +1039,6 @@
         requires_backends(cls, cls._backends)
 
 
-def torch_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_torch_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
-
-    return wrapper
-
-
-def tf_required(func):
-    # Chose a different decorator name than in tests so it's clear they are not the same.
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if is_tf_available():
-            return func(*args, **kwargs)
-        else:
-            raise ImportError(f"Method `{func.__name__}` requires TF.")
-
-    return wrapper
-
-
 def is_torch_fx_proxy(x):
     if is_torch_fx_available():
         import torch.fx
......
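
For context on the helper that replaces the two deleted decorators: `requires_backends(obj, backends)` checks each requested backend and raises an `ImportError` naming what is missing; the real implementation in `transformers.utils.import_utils` builds a more detailed error message. A rough approximation of that behavior, assuming the `is_torch_available` / `is_tf_available` checks already present in the same module; this is a sketch, not the library's actual source:

# Simplified sketch of what the retained helper does for the two backends
# touched in this diff (the real helper supports many more backends).
def requires_backends_sketch(obj, backends):
    checks = {"torch": is_torch_available, "tf": is_tf_available}
    # Works for both classes (decorated dummies) and instances (methods calling it on self).
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [backend for backend in backends if not checks[backend]()]
    if missing:
        raise ImportError(f"`{name}` requires the following backends: {', '.join(missing)}")
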