Unverified Commit 0f08cd20 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Smarter check for `is_tensor` (#25871)



* Smarter check for

* Use protected functions

* Do others too

* Apply suggestions from code review
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Address review comments

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 3fb1535b
...@@ -71,31 +71,64 @@ def strtobool(val): ...@@ -71,31 +71,64 @@ def strtobool(val):
raise ValueError(f"invalid truth value {val!r}") raise ValueError(f"invalid truth value {val!r}")
def infer_framework_from_repr(x):
"""
Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
frameworks in a smart order, without the need to import the frameworks).
"""
representation = repr(x)
if representation.startswith("tensor"):
return "pt"
elif "tf.Tensor" in representation:
return "tf"
elif representation.startswith("Array"):
return "jax"
elif representation.startswith("array"):
return "np"
def _get_frameworks_and_test_func(x):
"""
Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
we can guess from the repr first, then Numpy, then the others.
"""
framework_to_test = {
"pt": is_torch_tensor,
"tf": is_tf_tensor,
"jax": is_jax_tensor,
"np": is_numpy_array,
}
preferred_framework = infer_framework_from_repr(x)
# We will test this one first, then numpy, then the others.
frameworks = [] if preferred_framework is None else [preferred_framework]
if preferred_framework != "np":
frameworks.append("np")
frameworks.extend([f for f in framework_to_test if f not in [preferred_framework, "np"]])
return {f: framework_to_test[f] for f in frameworks}
def is_tensor(x): def is_tensor(x):
""" """
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`. Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray` in the order
defined by `infer_framework_from_repr`
""" """
if is_torch_fx_proxy(x): # This gives us a smart order to test the frameworks with the corresponding tests.
framework_to_test_func = _get_frameworks_and_test_func(x)
for test_func in framework_to_test_func.values():
if test_func(x):
return True return True
if is_torch_available():
import torch
if isinstance(x, torch.Tensor): # Tracers
return True if is_torch_fx_proxy(x):
if is_tf_available():
import tensorflow as tf
if isinstance(x, tf.Tensor):
return True return True
if is_flax_available(): if is_flax_available():
import jax.numpy as jnp
from jax.core import Tracer from jax.core import Tracer
if isinstance(x, (jnp.ndarray, Tracer)): if isinstance(x, Tracer):
return True return True
return isinstance(x, np.ndarray) return False
def _is_numpy(x): def _is_numpy(x):
...@@ -200,17 +233,27 @@ def to_py_obj(obj): ...@@ -200,17 +233,27 @@ def to_py_obj(obj):
""" """
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
""" """
framework_to_py_obj = {
"pt": lambda obj: obj.detach().cpu().tolist(),
"tf": lambda obj: obj.numpy().tolist(),
"jax": lambda obj: np.asarray(obj).tolist(),
"np": lambda obj: obj.tolist(),
}
if isinstance(obj, (dict, UserDict)): if isinstance(obj, (dict, UserDict)):
return {k: to_py_obj(v) for k, v in obj.items()} return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj] return [to_py_obj(o) for o in obj]
elif is_tf_tensor(obj):
return obj.numpy().tolist() # This gives us a smart order to test the frameworks with the corresponding tests.
elif is_torch_tensor(obj): framework_to_test_func = _get_frameworks_and_test_func(obj)
return obj.detach().cpu().tolist() for framework, test_func in framework_to_test_func.items():
elif is_jax_tensor(obj): if test_func(obj):
return np.asarray(obj).tolist() return framework_to_py_obj[framework](obj)
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
# tolist also works on 0d np arrays
if isinstance(obj, np.number):
return obj.tolist() return obj.tolist()
else: else:
return obj return obj
...@@ -220,17 +263,25 @@ def to_numpy(obj): ...@@ -220,17 +263,25 @@ def to_numpy(obj):
""" """
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array. Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.
""" """
framework_to_numpy = {
"pt": lambda obj: obj.detach().cpu().numpy(),
"tf": lambda obj: obj.numpy(),
"jax": lambda obj: np.asarray(obj),
"np": lambda obj: obj,
}
if isinstance(obj, (dict, UserDict)): if isinstance(obj, (dict, UserDict)):
return {k: to_numpy(v) for k, v in obj.items()} return {k: to_numpy(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return np.array(obj) return np.array(obj)
elif is_tf_tensor(obj):
return obj.numpy() # This gives us a smart order to test the frameworks with the corresponding tests.
elif is_torch_tensor(obj): framework_to_test_func = _get_frameworks_and_test_func(obj)
return obj.detach().cpu().numpy() for framework, test_func in framework_to_test_func.items():
elif is_jax_tensor(obj): if test_func(obj):
return np.asarray(obj) return framework_to_numpy[framework](obj)
else:
return obj return obj
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment