Unverified Commit 7d9a33fb authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF Bert inference - support `np.ndarray` optional arguments (#15074)

* TF Bert inference - support np.ndarray optional arguments

* apply np input tests to all TF architectures
parent 4663c609
......@@ -1941,16 +1941,19 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
def shape_list(tensor: tf.Tensor) -> List[int]:
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
tensor (`tf.Tensor`): The tensor we want the shape of.
tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
Returns:
`List[int]`: The shape of the tensor as a list.
"""
if isinstance(tensor, np.ndarray):
return list(tensor.shape)
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
......
......@@ -846,7 +846,9 @@ class TFModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class)
inputs_np = prepare_numpy_arrays(inputs)
model(inputs_np)
output_for_dict_input = model(inputs_np)
output_for_kw_input = model(**inputs_np)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_resize_token_embeddings(self):
if not self.test_resize_embeddings:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment