Commit f2f32940 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

Fix input embeddings

parent bdfe21ab
...@@ -49,6 +49,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): ...@@ -49,6 +49,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
embedding_size=16,
hidden_size=36, hidden_size=36,
num_hidden_layers=6, num_hidden_layers=6,
num_hidden_groups=6, num_hidden_groups=6,
...@@ -73,6 +74,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): ...@@ -73,6 +74,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
self.use_token_type_ids = use_token_type_ids self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
......
...@@ -54,6 +54,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -54,6 +54,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
embedding_size=16,
hidden_size=32, hidden_size=32,
num_hidden_layers=5, num_hidden_layers=5,
num_attention_heads=4, num_attention_heads=4,
...@@ -77,6 +78,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -77,6 +78,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
self.use_token_type_ids = use_token_type_ids self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
......
...@@ -426,9 +426,10 @@ class TFCommonTestCases: ...@@ -426,9 +426,10 @@ class TFCommonTestCases:
try: try:
x = wte([input_ids], mode="embedding") x = wte([input_ids], mode="embedding")
except: except:
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) if hasattr(self.model_tester, "embedding_size"):
# ^^ In our TF models, the input_embeddings can take slightly different forms, x = tf.ones(input_ids.shape + [model.config.embedding_size], dtype=tf.dtypes.float32)
# so we try two of them and fall back to just synthetically creating a dummy tensor of ones. else:
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
inputs_dict["inputs_embeds"] = x inputs_dict["inputs_embeds"] = x
outputs = model(inputs_dict) outputs = model(inputs_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment