Unverified Commit 7d9a33fb authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF Bert inference - support `np.ndarray` optional arguments (#15074)

* TF Bert inference - support np.ndarray optional arguments

* apply np input tests to all TF architectures
parent 4663c609
...@@ -1941,16 +1941,19 @@ class TFSequenceSummary(tf.keras.layers.Layer): ...@@ -1941,16 +1941,19 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output return output
def shape_list(tensor: tf.Tensor) -> List[int]: def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
""" """
Deal with dynamic shape in tensorflow cleanly. Deal with dynamic shape in tensorflow cleanly.
Args: Args:
tensor (`tf.Tensor`): The tensor we want the shape of. tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
Returns: Returns:
`List[int]`: The shape of the tensor as a list. `List[int]`: The shape of the tensor as a list.
""" """
if isinstance(tensor, np.ndarray):
return list(tensor.shape)
dynamic = tf.shape(tensor) dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None): if tensor.shape == tf.TensorShape(None):
......
...@@ -846,7 +846,9 @@ class TFModelTesterMixin: ...@@ -846,7 +846,9 @@ class TFModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class) inputs = self._prepare_for_class(inputs_dict, model_class)
inputs_np = prepare_numpy_arrays(inputs) inputs_np = prepare_numpy_arrays(inputs)
model(inputs_np) output_for_dict_input = model(inputs_np)
output_for_kw_input = model(**inputs_np)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_resize_token_embeddings(self): def test_resize_token_embeddings(self):
if not self.test_resize_embeddings: if not self.test_resize_embeddings:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment