Unverified Commit e152f295 authored by Kevin H. Luu's avatar Kevin H. Luu Committed by GitHub
Browse files

[misc] Reduce number of config file requests to HuggingFace (#12797)


Signed-off-by: default avatarEC2 Default User <ec2-user@ip-172-31-20-117.us-west-2.compute.internal>
Signed-off-by: <>
Co-authored-by: default avatarEC2 Default User <ec2-user@ip-172-31-20-117.us-west-2.compute.internal>
parent c786e757
......@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download,
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
try_to_load_from_cache)
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
LocalEntryNotFoundError,
......@@ -395,7 +395,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
- dict: A dictionary containing the configuration parameters
for the Sentence Transformer BERT model.
"""
for config_name in [
sentence_transformer_config_files = [
"sentence_bert_config.json",
"sentence_roberta_config.json",
"sentence_distilbert_config.json",
......@@ -403,7 +403,17 @@ def get_sentence_transformer_tokenizer_config(model: str,
"sentence_albert_config.json",
"sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json",
]:
]
try:
# If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN)
except Exception as e:
logger.debug("Error getting repo files", e)
repo_files = []
encoder_dict = None
for config_name in sentence_transformer_config_files:
if config_name in repo_files or Path(model).exists():
encoder_dict = get_hf_file_to_dict(config_name, model, revision)
if encoder_dict:
break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment