Unverified commit 06886d5a authored by Sylvain Gugger, committed by GitHub

Only resize embeddings when necessary (#20043)

* Only resize embeddings when necessary

* Add comment
parent 9080607b
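Every hunk below applies the same guard: instead of unconditionally calling `resize_token_embeddings`, each example script first reads the number of rows in the model's input-embedding matrix and resizes only if the tokenizer defines more tokens than that matrix can index. A minimal sketch of the pattern, using a GPT-2 checkpoint purely for illustration (the checkpoint name is not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; each script below applies the same check to whatever
# model and tokenizer it has just built or loaded.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Number of rows in the input-embedding matrix (the model's current vocab size).
embedding_size = model.get_input_embeddings().weight.shape[0]

# Resize only when the tokenizer has more tokens than the embedding matrix can
# index; this avoids index errors while leaving an already-large matrix alone.
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
```

As the added comment notes, the test should be removed if you are training a model from scratch on a small vocabulary and want the embedding matrix shrunk to match it.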
@@ -387,6 +387,10 @@ def main():
n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -378,6 +378,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -389,6 +389,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForMaskedLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -383,6 +383,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForMaskedLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -376,6 +376,10 @@ def main():
logger.info("Training new model from scratch")
model = XLNetLMHeadModel(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -398,6 +398,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForMultipleChoice.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
@@ -380,6 +380,10 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None:
@@ -422,6 +422,10 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
@@ -439,6 +439,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForSeq2SeqLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -414,6 +414,12 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForTokenClassification.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Model has labels -> use them.
@@ -380,6 +380,10 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Set decoder_start_token_id
@@ -411,6 +411,10 @@ def main():
logger.info("Training new model from scratch")
model = AutoModelForSeq2SeqLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# Set decoder_start_token_id
@@ -473,6 +473,10 @@ def main():
logger.info("Training new model from scratch")
model = TFAutoModelForCausalLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# endregion
@@ -489,6 +489,10 @@ def main():
logger.info("Training new model from scratch")
model = TFAutoModelForMaskedLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# endregion
@@ -516,6 +516,10 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# endregion
@@ -385,6 +385,10 @@ def main():
logger.info("Training new model from scratch")
model = TFAutoModelForTokenClassification.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
# endregion
@@ -469,6 +469,10 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
model.config.forced_bos_token_id = forced_bos_token_id
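For context, here is one scenario, an assumption on our part rather than something stated in the diff, where the old unconditional resize would have been harmful and the new guard is a no-op: a from-scratch config whose `vocab_size` is deliberately padded beyond the tokenizer size.

```python
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # 30522 tokens; illustrative only

# Assumed setup: vocab_size padded past the tokenizer size, e.g. for hardware-friendly shapes.
config = BertConfig(vocab_size=32768)
model = BertForMaskedLM(config)

embedding_size = model.get_input_embeddings().weight.shape[0]  # 32768
if len(tokenizer) > embedding_size:  # 30522 > 32768 is False, so nothing happens
    model.resize_token_embeddings(len(tokenizer))

# The padded 32768-row embedding matrix is left untouched; an unconditional
# resize_token_embeddings(len(tokenizer)) would have shrunk it to 30522 rows.
```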