"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "06e782da4e58f93a60c6bedc84b5991abaae58f5"
Unverified commit d4796653, authored by Joao Gante, committed by GitHub

Examples: check `max_position_embeddings` in the translation example (#29600)

check max_position_embeddings
parent 6b660d5e
@@ -469,6 +469,19 @@ def main():
     source_lang = data_args.source_lang.split("_")[0]
     target_lang = data_args.target_lang.split("_")[0]
 
+    # Check whether the source length fits in the model, if it has absolute position embeddings
+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and not hasattr(model.config, "relative_attention_max_distance")
+        and model.config.max_position_embeddings < data_args.max_source_length
+    ):
+        raise ValueError(
+            f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
+            f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
+            f" `--max_source_length` to {model.config.max_position_embeddings} or using a model with larger position "
+            "embeddings"
+        )
+
     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
     padding = "max_length" if data_args.pad_to_max_length else False
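For context, here is a minimal standalone sketch of the same guard. The `check_source_length` helper, the `SimpleNamespace` stand-ins for `model.config`, and the example values (512 and 128) are all hypothetical illustrations; the real script reads these from the loaded model and the parsed `--max_source_length` argument.

    from types import SimpleNamespace


    def check_source_length(config, max_source_length):
        # Mirror the guard added above: only configs that expose absolute
        # position embeddings, and no relative-attention fallback, bound
        # the usable source length.
        if (
            hasattr(config, "max_position_embeddings")
            and not hasattr(config, "relative_attention_max_distance")
            and config.max_position_embeddings < max_source_length
        ):
            raise ValueError(
                f"`--max_source_length` is set to {max_source_length}, but the model only has"
                f" {config.max_position_embeddings} position encodings."
            )


    # A config with absolute position embeddings (hypothetical value): the guard fires.
    absolute = SimpleNamespace(max_position_embeddings=512)
    try:
        check_source_length(absolute, max_source_length=1024)
    except ValueError as err:
        print(err)

    # A T5-style config advertises `relative_attention_max_distance`; its relative
    # position buckets are not capped by `max_position_embeddings`, so the guard
    # is skipped even for long inputs.
    relative = SimpleNamespace(relative_attention_max_distance=128)
    check_source_length(relative, max_source_length=1024)  # no error raised

The `relative_attention_max_distance` exclusion is why models such as T5, which encode positions with relative buckets rather than absolute embeddings, are not rejected for long sources.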