Unverified Commit 1d3a1cc4 authored by Matt, committed by GitHub

Add check for different embedding types in examples (#21881)

* Add check for different embedding types in examples

* Correctly update summarization example
parent 53735d7c
@@ -475,7 +475,15 @@ def main():
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
+    embeddings = model.get_input_embeddings()
+    # Matt: This is a temporary workaround as we transition our models to exclusively using Keras embeddings.
+    # As soon as the transition is complete, all embeddings should be keras.Embeddings layers, and
+    # the weights will always be in embeddings.embeddings.
+    if hasattr(embeddings, "embeddings"):
+        embedding_size = embeddings.embeddings.shape[0]
+    else:
+        embedding_size = embeddings.weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
     # endregion
......
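For reference, the new check can be read as a small standalone helper. A minimal sketch, assuming a TF example script; the helper name `get_vocab_size` and the DistilBERT checkpoint are illustrative choices, not part of this commit:

```python
from transformers import AutoTokenizer, TFAutoModelForMaskedLM


def get_vocab_size(model) -> int:
    """Return the vocab size of the input embeddings, whichever layer type backs them."""
    embeddings = model.get_input_embeddings()
    # Plain keras.layers.Embedding stores its weight matrix in `.embeddings`;
    # the older custom embedding layers expose it as `.weight` instead.
    if hasattr(embeddings, "embeddings"):
        return embeddings.embeddings.shape[0]
    return embeddings.weight.shape[0]


# Illustrative checkpoint, not part of the commit.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = TFAutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

# Resize only when the tokenizer has grown past the embedding matrix,
# mirroring the guard in the examples above.
if len(tokenizer) > get_vocab_size(model):
    model.resize_token_embeddings(len(tokenizer))
```

The `hasattr` branch is what lets the examples work both before and after a model is ported to a plain `keras.layers.Embedding`, whose weight matrix lives in the layer's `.embeddings` attribute rather than `.weight`.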
@@ -491,7 +491,15 @@ def main():
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
+    embeddings = model.get_input_embeddings()
+    # Matt: This is a temporary workaround as we transition our models to exclusively using Keras embeddings.
+    # As soon as the transition is complete, all embeddings should be keras.Embeddings layers, and
+    # the weights will always be in embeddings.embeddings.
+    if hasattr(embeddings, "embeddings"):
+        embedding_size = embeddings.embeddings.shape[0]
+    else:
+        embedding_size = embeddings.weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
     # endregion
......
@@ -518,7 +518,15 @@ def main():
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
+    embeddings = model.get_input_embeddings()
+    # Matt: This is a temporary workaround as we transition our models to exclusively using Keras embeddings.
+    # As soon as the transition is complete, all embeddings should be keras.Embeddings layers, and
+    # the weights will always be in embeddings.embeddings.
+    if hasattr(embeddings, "embeddings"):
+        embedding_size = embeddings.embeddings.shape[0]
+    else:
+        embedding_size = embeddings.weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
     # endregion
......
@@ -387,7 +387,15 @@ def main():
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
+    embeddings = model.get_input_embeddings()
+    # Matt: This is a temporary workaround as we transition our models to exclusively using Keras embeddings.
+    # As soon as the transition is complete, all embeddings should be keras.Embeddings layers, and
+    # the weights will always be in embeddings.embeddings.
+    if hasattr(embeddings, "embeddings"):
+        embedding_size = embeddings.embeddings.shape[0]
+    else:
+        embedding_size = embeddings.weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
     # endregion
......
@@ -471,9 +471,18 @@ def main():
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
+    embeddings = model.get_input_embeddings()
+    # Matt: This is a temporary workaround as we transition our models to exclusively using Keras embeddings.
+    # As soon as the transition is complete, all embeddings should be keras.Embeddings layers, and
+    # the weights will always be in embeddings.embeddings.
+    if hasattr(embeddings, "embeddings"):
+        embedding_size = embeddings.embeddings.shape[0]
+    else:
+        embedding_size = embeddings.weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
         model.config.forced_bos_token_id = forced_bos_token_id
     # endregion
......
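The summarization hunk above also keeps a context branch that forces the decoder to start with the target-language token for multilingual tokenizers. A minimal sketch of that branch, assuming an mBART-style checkpoint; the checkpoint name and the "ro_RO" language code are our illustrative choices, not from this commit:

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Assumption: any mBART-style multilingual seq2seq checkpoint with TF weights.
checkpoint = "facebook/mbart-large-50"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# mBART-style tokenizers encode each language as a special token, and
# generation must be forced to begin with the target-language token.
forced_bos_token_id = tokenizer.convert_tokens_to_ids("ro_RO")
model.config.forced_bos_token_id = forced_bos_token_id
```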