"...resnet50_tensorflow.git" did not exist on "992a5bfaa9ee90b8cd5ffc5f69fc32f34669ebff"
Unverified Commit 919a964b authored by Sergio Valcarcel Macua's avatar Sergio Valcarcel Macua Committed by GitHub
Browse files

Include Keras tensor in the allowed types (#14155)



* Include KerasTensor in allowed types

- This allows propagating symbolic tensors through TFBert models and layers' call(),
  which allows converting the subclass models to functional models.

* Style pass
Co-authored-by: default avatarSergio Valcarcel Macua <sergiov@graphcore.ai>
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
parent f5ed19f5
...@@ -27,6 +27,7 @@ import numpy as np ...@@ -27,6 +27,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
...@@ -52,7 +53,15 @@ logger = logging.get_logger(__name__) ...@@ -52,7 +53,15 @@ logger = logging.get_logger(__name__)
tf_logger = tf.get_logger() tf_logger = tf.get_logger()
TFModelInputType = Union[ TFModelInputType = Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor List[tf.Tensor],
List[np.ndarray],
List[KerasTensor],
Dict[str, tf.Tensor],
Dict[str, np.ndarray],
Dict[str, KerasTensor],
tf.Tensor,
np.ndarray,
KerasTensor,
] ]
...@@ -348,7 +357,7 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -348,7 +357,7 @@ def input_processing(func, config, input_ids, **kwargs):
signature.pop("self", None) signature.pop("self", None)
parameter_names = list(signature.keys()) parameter_names = list(signature.keys())
output = {} output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
if "inputs" in kwargs["kwargs_call"]: if "inputs" in kwargs["kwargs_call"]:
warnings.warn( warnings.warn(
...@@ -432,7 +441,7 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -432,7 +441,7 @@ def input_processing(func, config, input_ids, **kwargs):
else: else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else: else:
if isinstance(input_ids, tf.Tensor) or input_ids is None: if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
output[parameter_names[0]] = input_ids output[parameter_names[0]] = input_ids
else: else:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment