Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cf62bdc9
Commit
cf62bdc9
authored
Nov 26, 2019
by
Julien Chaumond
Browse files
Improve test protocol for inputs_embeds in TF
cc @lysandrejik
parent
b6321452
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
4 deletions
+9
-4
transformers/tests/modeling_tf_common_test.py
transformers/tests/modeling_tf_common_test.py
+9
-4
No files found.
transformers/tests/modeling_tf_common_test.py
View file @
cf62bdc9
...
...
@@ -426,10 +426,15 @@ class TFCommonTestCases:
try
:
x
=
wte
([
input_ids
],
mode
=
"embedding"
)
except
:
if
hasattr
(
self
.
model_tester
,
"embedding_size"
):
x
=
tf
.
ones
(
input_ids
.
shape
+
[
model
.
config
.
embedding_size
],
dtype
=
tf
.
dtypes
.
float32
)
else
:
x
=
tf
.
ones
(
input_ids
.
shape
+
[
self
.
model_tester
.
hidden_size
],
dtype
=
tf
.
dtypes
.
float32
)
x
=
wte
([
input_ids
,
None
,
None
,
None
],
mode
=
"embedding"
)
# ^^ In our TF models, the input_embeddings can take slightly different forms,
# so we try a few of them.
# We used to fall back to just synthetically creating a dummy tensor of ones:
#
# if hasattr(self.model_tester, "embedding_size"):
# x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32)
# else:
# x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
inputs_dict
[
"inputs_embeds"
]
=
x
outputs
=
model
(
inputs_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment