Unverified Commit 515d6a55 authored by Roy Hvaara's avatar Roy Hvaara Committed by GitHub
Browse files

[tensorflow] Add support for the `is_symbolic_tensor` predicate (#22878)



This predicate will become available in tensorflow starting with version
2.14.
Co-authored-by: default avatarRussell Power <power@google.com>
parent 5764e67c
...@@ -57,6 +57,7 @@ from .utils import ( ...@@ -57,6 +57,7 @@ from .utils import (
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
is_safetensors_available, is_safetensors_available,
is_tf_symbolic_tensor,
logging, logging,
requires_backends, requires_backends,
working_or_temp_dir, working_or_temp_dir,
...@@ -511,7 +512,7 @@ def input_processing(func, config, **kwargs): ...@@ -511,7 +512,7 @@ def input_processing(func, config, **kwargs):
if isinstance(main_input, (tuple, list)): if isinstance(main_input, (tuple, list)):
for i, input in enumerate(main_input): for i, input in enumerate(main_input):
# EagerTensors don't allow to use the .name property so we check for a real Tensor # EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor: if is_tf_symbolic_tensor(input):
# Tensor names have always the pattern `name:id` then we check only the # Tensor names have always the pattern `name:id` then we check only the
# `name` part # `name` part
tensor_name = input.name.split(":")[0] tensor_name = input.name.split(":")[0]
...@@ -572,7 +573,7 @@ def input_processing(func, config, **kwargs): ...@@ -572,7 +573,7 @@ def input_processing(func, config, **kwargs):
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception # So to respect the proper output we have to add this exception
if "args" in output: if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor: if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
tensor_name = output["args"].name.split(":")[0] tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"] output[tensor_name] = output["args"]
else: else:
......
...@@ -42,6 +42,7 @@ from .generic import ( ...@@ -42,6 +42,7 @@ from .generic import (
is_jax_tensor, is_jax_tensor,
is_numpy_array, is_numpy_array,
is_tensor, is_tensor,
is_tf_symbolic_tensor,
is_tf_tensor, is_tf_tensor,
is_torch_device, is_torch_device,
is_torch_dtype, is_torch_dtype,
......
...@@ -166,6 +166,23 @@ def is_tf_tensor(x): ...@@ -166,6 +166,23 @@ def is_tf_tensor(x):
return False if not is_tf_available() else _is_tensorflow(x) return False if not is_tf_available() else _is_tensorflow(x)
def _is_tf_symbolic_tensor(x):
import tensorflow as tf
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
if hasattr(tf, "is_symbolic_tensor"):
return tf.is_symbolic_tensor(x)
return type(x) == tf.Tensor
def is_tf_symbolic_tensor(x):
"""
Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
installed.
"""
return False if not is_tf_available() else _is_tf_symbolic_tensor(x)
def _is_jax(x): def _is_jax(x):
import jax.numpy as jnp # noqa: F811 import jax.numpy as jnp # noqa: F811
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment