Unverified Commit 06886d5a authored by Sylvain Gugger, committed by GitHub

Only resize embeddings when necessary (#20043)

* Only resize embeddings when necessary

* Add comment
parent 9080607b
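The same change is applied across the PyTorch and TensorFlow example scripts: instead of unconditionally calling model.resize_token_embeddings(len(tokenizer)), each script first reads the row count of the model's input-embedding matrix and only resizes when the tokenizer's vocabulary is larger. A minimal, self-contained sketch of the pattern, assuming a standard tokenizer/model pair (the gpt2 checkpoint is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any tokenizer/model pair follows the same pattern.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The input-embedding weight is (vocab_size, hidden_size); its row count is
# the number of token ids the model can currently index.
embedding_size = model.get_input_embeddings().weight.shape[0]

# Resize only when the tokenizer holds more tokens than the model can index.
# An unconditional resize can shrink the matrix below the configured vocab
# size, leaving existing token ids out of range (the "index errors" the new
# comment refers to).
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))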
@@ -387,7 +387,11 @@ def main():
         n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
         logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
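The escape hatch in the comment covers the opposite case: a model created from scratch with a small custom tokenizer, where shrinking the embedding matrix is exactly what you want. A hedged sketch of that case (the base config and vocab size are illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative: a fresh, untrained model from a stock config.
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)

# With a small custom vocab, remove the test and resize unconditionally so
# the embedding matrix shrinks to match and saves parameters.
small_vocab_size = 8000  # stand-in for len(tokenizer) of a custom tokenizer
model.resize_token_embeddings(small_vocab_size)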
@@ -378,7 +378,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForCausalLM.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
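The index errors the new comment mentions are plain embedding-lookup failures: an id at or beyond the matrix's row count cannot be gathered. A toy PyTorch illustration:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)  # toy 10-token vocab
print(emb(torch.tensor([0, 9])).shape)  # ids in range: works, shape (2, 4)
try:
    emb(torch.tensor([10]))  # id 10 is out of range for 10 rows
except IndexError as err:
    print("IndexError:", err)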
@@ -389,7 +389,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMaskedLM.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
@@ -383,7 +383,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMaskedLM.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
@@ -376,7 +376,11 @@ def main():
         logger.info("Training new model from scratch")
         model = XLNetLMHeadModel(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
@@ -398,7 +398,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForMultipleChoice.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
...
@@ -380,7 +380,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
...
@@ -422,7 +422,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         if isinstance(tokenizer, MBartTokenizer):
...
@@ -439,7 +439,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForSeq2SeqLM.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
...
@@ -414,7 +414,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForTokenClassification.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
...
@@ -380,7 +380,11 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
...
@@ -411,7 +411,11 @@ def main():
         logger.info("Training new model from scratch")
         model = AutoModelForSeq2SeqLM.from_config(config)
 
-    model.resize_token_embeddings(len(tokenizer))
+    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+    # on a small vocab and want a smaller embedding size, remove this test.
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
...
@@ -473,7 +473,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForCausalLM.from_config(config)
 
-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion
 
         # region TF Dataset preparation
...
@@ -489,7 +489,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForMaskedLM.from_config(config)
 
-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion
 
         # region TF Dataset preparation
...
@@ -516,7 +516,11 @@ def main():
             use_auth_token=True if model_args.use_auth_token else None,
         )
 
-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion
 
         # region Prepare TF Dataset objects
...
@@ -385,7 +385,11 @@ def main():
             logger.info("Training new model from scratch")
             model = TFAutoModelForTokenClassification.from_config(config)
 
-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         # endregion
 
         # region Create TF datasets
...
@@ -469,7 +469,11 @@ def main():
             use_auth_token=True if model_args.use_auth_token else None,
         )
 
-        model.resize_token_embeddings(len(tokenizer))
+        # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+        # on a small vocab and want a smaller embedding size, remove this test.
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
         if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
             model.config.forced_bos_token_id = forced_bos_token_id
         # endregion
...
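With the guard in place, resizing still triggers whenever the tokenizer genuinely outgrows the model, for example after adding tokens. A brief usage sketch continuing from any tokenizer/model pair above (the added token string is illustrative):

# Adding tokens grows len(tokenizer) past the current embedding row count,
# so the guarded resize now fires.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<my_new_token>"]})
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))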