"vscode:/vscode.git/clone" did not exist on "46ef646016eab3a8bf3c4f075c6982888c439022"
Commit cf62bdc9 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Improve test protocol for inputs_embeds in TF

cc @lysandrejik
parent b6321452
......@@ -426,10 +426,15 @@ class TFCommonTestCases:
try:
x = wte([input_ids], mode="embedding")
except:
if hasattr(self.model_tester, "embedding_size"):
x = tf.ones(input_ids.shape + [model.config.embedding_size], dtype=tf.dtypes.float32)
else:
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
x = wte([input_ids, None, None, None], mode="embedding")
# ^^ In our TF models, the input_embeddings can take slightly different forms,
# so we try a few of them.
# We used to fall back to just synthetically creating a dummy tensor of ones:
#
# if hasattr(self.model_tester, "embedding_size"):
# x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
# else:
# x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
inputs_dict["inputs_embeds"] = x
outputs = model(inputs_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment