Unverified Commit 676643c6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Better logic for getting tokenizer config in AutoTokenizer (#14906)

* Better logic for getting tokenizer config in AutoTokenizer

* Remove needless import

* Remove debug statement

* Address review comments
parent f566c6e3
...@@ -50,6 +50,7 @@ from tqdm.auto import tqdm ...@@ -50,6 +50,7 @@ from tqdm.auto import tqdm
import requests import requests
from filelock import FileLock from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
from . import __version__ from . import __version__
...@@ -2100,7 +2101,13 @@ def get_list_of_files( ...@@ -2100,7 +2101,13 @@ def get_list_of_files(
token = HfFolder.get_token() token = HfFolder.get_token()
else: else:
token = None token = None
return list_repo_files(path_or_repo, revision=revision, token=token)
try:
return list_repo_files(path_or_repo, revision=revision, token=token)
except HTTPError as e:
raise ValueError(
f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
) from e
class cached_property(property): class cached_property(property):
......
...@@ -18,11 +18,13 @@ import importlib ...@@ -18,11 +18,13 @@ import importlib
import json import json
import os import os
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import ( from ...file_utils import (
cached_path, cached_path,
get_list_of_files,
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_sentencepiece_available, is_sentencepiece_available,
...@@ -330,6 +332,16 @@ def get_tokenizer_config( ...@@ -330,6 +332,16 @@ def get_tokenizer_config(
logger.info("Offline mode: forcing local_files_only=True") logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True local_files_only = True
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
repo_files = get_list_of_files(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
return {}
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
...@@ -350,7 +362,7 @@ def get_tokenizer_config( ...@@ -350,7 +362,7 @@ def get_tokenizer_config(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
except (EnvironmentError, ValueError): except EnvironmentError:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {} return {}
......
...@@ -149,7 +149,9 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -149,7 +149,9 @@ class AutoTokenizerTest(unittest.TestCase):
@require_tokenizers @require_tokenizers
def test_tokenizer_identifier_non_existent(self): def test_tokenizer_identifier_non_existent(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
with self.assertRaises(EnvironmentError): with self.assertRaisesRegex(
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?"
):
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists") _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
def test_parents_and_children_in_mappings(self): def test_parents_and_children_in_mappings(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment