Unverified commit 9151e649, authored by Sylvain Gugger and committed by GitHub
Browse files

Make public versions of private tensor utils (#19775)

* Make public versions of private utils

* I need sleep
parent 3aaabaa2
...@@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union ...@@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
from .utils.generic import _is_tensorflow, _is_torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): ...@@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
first_element = required_input[index][0] first_element = required_input[index][0]
if return_tensors is None: if return_tensors is None:
if is_tf_available() and _is_tensorflow(first_element): if is_tf_tensor(first_element):
return_tensors = "tf" return_tensors = "tf"
elif is_torch_available() and _is_torch(first_element): elif is_torch_tensor(first_element):
return_tensors = "pt" return_tensors = "pt"
elif isinstance(first_element, (int, float, list, tuple, np.ndarray)): elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
return_tensors = "np" return_tensors = "np"
......
...@@ -33,14 +33,16 @@ from .utils import ( ...@@ -33,14 +33,16 @@ from .utils import (
copy_func, copy_func,
download_url, download_url,
is_flax_available, is_flax_available,
is_jax_tensor,
is_numpy_array,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
is_torch_device,
logging, logging,
torch_required, torch_required,
) )
from .utils.generic import _is_jax, _is_numpy, _is_torch_device
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -150,10 +152,10 @@ class BatchFeature(UserDict): ...@@ -150,10 +152,10 @@ class BatchFeature(UserDict):
import jax.numpy as jnp # noqa: F811 import jax.numpy as jnp # noqa: F811
as_tensor = jnp.array as_tensor = jnp.array
is_tensor = _is_jax is_tensor = is_jax_tensor
else: else:
as_tensor = np.asarray as_tensor = np.asarray
is_tensor = _is_numpy is_tensor = is_numpy_array
# Do the tensor conversion in batch # Do the tensor conversion in batch
for key, value in self.items(): for key, value in self.items():
...@@ -188,7 +190,7 @@ class BatchFeature(UserDict): ...@@ -188,7 +190,7 @@ class BatchFeature(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor # into a HalfTensor
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()} self.data = {k: v.to(device=device) for k, v in self.data.items()}
else: else:
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
......
...@@ -21,14 +21,21 @@ from packaging import version ...@@ -21,14 +21,21 @@ from packaging import version
import requests import requests
from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available from .utils import (
ExplicitEnum,
is_jax_tensor,
is_tf_tensor,
is_torch_available,
is_torch_tensor,
is_vision_available,
to_numpy,
)
from .utils.constants import ( # noqa: F401 from .utils.constants import ( # noqa: F401
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD, IMAGENET_STANDARD_STD,
) )
from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy
if is_vision_available(): if is_vision_available():
...@@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum): ...@@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum):
LAST = "channels_last" LAST = "channels_last"
def is_torch_tensor(obj):
    """
    Return whether `obj` is a torch tensor.

    `_is_torch` is only called when `is_torch_available()` is true; otherwise the answer is `False`.
    """
    if is_torch_available():
        return _is_torch(obj)
    return False
def is_tf_tensor(obj):
    """
    Return whether `obj` is a TensorFlow tensor.

    `_is_tensorflow` is only called when `is_tf_available()` is true; otherwise the answer is `False`.
    """
    if is_tf_available():
        return _is_tensorflow(obj)
    return False
def is_jax_tensor(obj):
    """
    Return whether `obj` is a Jax tensor.

    `_is_jax` is only called when `is_flax_available()` is true; otherwise the answer is `False`.
    """
    if is_flax_available():
        return _is_jax(obj)
    return False
def is_valid_image(img): def is_valid_image(img):
return ( return (
isinstance(img, (PIL.Image.Image, np.ndarray)) isinstance(img, (PIL.Image.Image, np.ndarray))
......
...@@ -33,11 +33,9 @@ from ...tokenization_utils_base import ( ...@@ -33,11 +33,9 @@ from ...tokenization_utils_base import (
TextInput, TextInput,
TextInputPair, TextInputPair,
TruncationStrategy, TruncationStrategy,
_is_tensorflow,
_is_torch,
to_py_obj, to_py_obj,
) )
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer): ...@@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer):
first_element = required_input[index][0] first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)): if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element): if is_tf_tensor(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element): elif is_torch_tensor(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray): elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors return_tensors = "np" if return_tensors is None else return_tensors
......
...@@ -37,11 +37,9 @@ from ...tokenization_utils_base import ( ...@@ -37,11 +37,9 @@ from ...tokenization_utils_base import (
TextInput, TextInput,
TextInputPair, TextInputPair,
TruncationStrategy, TruncationStrategy,
_is_tensorflow,
_is_torch,
to_py_obj, to_py_obj,
) )
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer): ...@@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer):
first_element = required_input[index][0] first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)): if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element): if is_tf_tensor(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element): elif is_torch_tensor(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray): elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors return_tensors = "np" if return_tensors is None else return_tensors
......
...@@ -45,16 +45,20 @@ from .utils import ( ...@@ -45,16 +45,20 @@ from .utils import (
download_url, download_url,
extract_commit_hash, extract_commit_hash,
is_flax_available, is_flax_available,
is_jax_tensor,
is_numpy_array,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
is_tf_available, is_tf_available,
is_tf_tensor,
is_tokenizers_available, is_tokenizers_available,
is_torch_available, is_torch_available,
is_torch_device,
is_torch_tensor,
logging, logging,
to_py_obj, to_py_obj,
torch_required, torch_required,
) )
from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -696,15 +700,10 @@ class BatchEncoding(UserDict): ...@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
import jax.numpy as jnp # noqa: F811 import jax.numpy as jnp # noqa: F811
as_tensor = jnp.array as_tensor = jnp.array
is_tensor = _is_jax is_tensor = is_jax_tensor
else: else:
as_tensor = np.asarray as_tensor = np.asarray
is_tensor = _is_numpy is_tensor = is_numpy_array
# (mfuntowicz: This code is unreachable)
# else:
# raise ImportError(
# f"Unable to convert output to tensors format {tensor_type}"
# )
# Do the tensor conversion in batch # Do the tensor conversion in batch
for key, value in self.items(): for key, value in self.items():
...@@ -753,7 +752,7 @@ class BatchEncoding(UserDict): ...@@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor # into a HalfTensor
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()} self.data = {k: v.to(device=device) for k, v in self.data.items()}
else: else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
...@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
break break
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)): if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element): if is_tf_tensor(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element): elif is_torch_tensor(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray): elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors return_tensors = "np" if return_tensors is None else return_tensors
......
...@@ -40,7 +40,12 @@ from .generic import ( ...@@ -40,7 +40,12 @@ from .generic import (
cached_property, cached_property,
find_labels, find_labels,
flatten_dict, flatten_dict,
is_jax_tensor,
is_numpy_array,
is_tensor, is_tensor,
is_tf_tensor,
is_torch_device,
is_torch_tensor,
to_numpy, to_numpy,
to_py_obj, to_py_obj,
working_or_temp_dir, working_or_temp_dir,
......
...@@ -83,30 +83,65 @@ def _is_numpy(x): ...@@ -83,30 +83,65 @@ def _is_numpy(x):
return isinstance(x, np.ndarray) return isinstance(x, np.ndarray)
def is_numpy_array(x):
    """
    Check whether `x` is a numpy array.

    Public entry point that delegates to the private `_is_numpy` helper and returns its result unchanged.
    """
    outcome = _is_numpy(x)
    return outcome
def _is_torch(x): def _is_torch(x):
import torch import torch
return isinstance(x, torch.Tensor) return isinstance(x, torch.Tensor)
def is_torch_tensor(x):
    """
    Check whether `x` is a torch tensor.

    Returns `False` straight away when `is_torch_available()` is false, so `_is_torch` (and the torch import it
    performs) is only reached when torch is reported as available.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x): def _is_torch_device(x):
import torch import torch
return isinstance(x, torch.device) return isinstance(x, torch.device)
def is_torch_device(x):
    """
    Check whether `x` is a torch device.

    Returns `False` straight away when `is_torch_available()` is false, so `_is_torch_device` (and the torch import
    it performs) is only reached when torch is reported as available.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)
def _is_tensorflow(x): def _is_tensorflow(x):
import tensorflow as tf import tensorflow as tf
return isinstance(x, tf.Tensor) return isinstance(x, tf.Tensor)
def is_tf_tensor(x):
    """
    Check whether `x` is a tensorflow tensor.

    Returns `False` straight away when `is_tf_available()` is false, so `_is_tensorflow` (and the tensorflow import
    it performs) is only reached when tensorflow is reported as available.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
def _is_jax(x): def _is_jax(x):
import jax.numpy as jnp # noqa: F811 import jax.numpy as jnp # noqa: F811
return isinstance(x, jnp.ndarray) return isinstance(x, jnp.ndarray)
def is_jax_tensor(x):
    """
    Check whether `x` is a Jax tensor.

    Returns `False` straight away when `is_flax_available()` is false, so `_is_jax` (and the `jax.numpy` import it
    performs) is only reached when that availability check passes. Note that the guard is the flax availability
    check rather than a jax-specific one.
    """
    if not is_flax_available():
        return False
    return _is_jax(x)
def to_py_obj(obj): def to_py_obj(obj):
""" """
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
...@@ -115,11 +150,11 @@ def to_py_obj(obj): ...@@ -115,11 +150,11 @@ def to_py_obj(obj):
return {k: to_py_obj(v) for k, v in obj.items()} return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj] return [to_py_obj(o) for o in obj]
elif is_tf_available() and _is_tensorflow(obj): elif is_tf_tensor(obj):
return obj.numpy().tolist() return obj.numpy().tolist()
elif is_torch_available() and _is_torch(obj): elif is_torch_tensor(obj):
return obj.detach().cpu().tolist() return obj.detach().cpu().tolist()
elif is_flax_available() and _is_jax(obj): elif is_jax_tensor(obj):
return np.asarray(obj).tolist() return np.asarray(obj).tolist()
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
return obj.tolist() return obj.tolist()
...@@ -135,11 +170,11 @@ def to_numpy(obj): ...@@ -135,11 +170,11 @@ def to_numpy(obj):
return {k: to_numpy(v) for k, v in obj.items()} return {k: to_numpy(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return np.array(obj) return np.array(obj)
elif is_tf_available() and _is_tensorflow(obj): elif is_tf_tensor(obj):
return obj.numpy() return obj.numpy()
elif is_torch_available() and _is_torch(obj): elif is_torch_tensor(obj):
return obj.detach().cpu().numpy() return obj.detach().cpu().numpy()
elif is_flax_available() and _is_jax(obj): elif is_jax_tensor(obj):
return np.asarray(obj) return np.asarray(obj)
else: else:
return obj return obj
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment