Unverified commit 377cdded, authored by Sylvain Gugger and committed by GitHub
Browse files

Clean up hub (#18497)

* Clean up utils.hub

* Remove imports

* More fixes

* Last fix
parent a4562552
...@@ -441,7 +441,6 @@ _import_structure = { ...@@ -441,7 +441,6 @@ _import_structure = {
"TensorType", "TensorType",
"add_end_docstrings", "add_end_docstrings",
"add_start_docstrings", "add_start_docstrings",
"cached_path",
"is_apex_available", "is_apex_available",
"is_datasets_available", "is_datasets_available",
"is_faiss_available", "is_faiss_available",
...@@ -3214,7 +3213,6 @@ if TYPE_CHECKING: ...@@ -3214,7 +3213,6 @@ if TYPE_CHECKING:
TensorType, TensorType,
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
cached_path,
is_apex_available, is_apex_available,
is_datasets_available, is_datasets_available,
is_faiss_available, is_faiss_available,
......
...@@ -38,7 +38,6 @@ from . import ( ...@@ -38,7 +38,6 @@ from . import (
T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
...@@ -91,11 +90,10 @@ from . import ( ...@@ -91,11 +90,10 @@ from . import (
XLMConfig, XLMConfig,
XLMRobertaConfig, XLMRobertaConfig,
XLNetConfig, XLNetConfig,
cached_path,
is_torch_available, is_torch_available,
load_pytorch_checkpoint_in_tf2_model, load_pytorch_checkpoint_in_tf2_model,
) )
from .utils import hf_bucket_url, logging from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
if is_torch_available(): if is_torch_available():
...@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf( ...@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(
# Initialise TF model # Initialise TF model
if config_file in aws_config_map: if config_file in aws_config_map:
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models) config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
config = config_class.from_json_file(config_file) config = config_class.from_json_file(config_file)
config.output_hidden_states = True config.output_hidden_states = True
config.output_attentions = True config.output_attentions = True
...@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf( ...@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf(
# Load weights from tf checkpoint # Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_config_map.keys(): if pytorch_checkpoint_path in aws_config_map.keys():
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME) pytorch_checkpoint_path = cached_file(
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models) pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
)
# Load PyTorch checkpoint in tf2 model: # Load PyTorch checkpoint in tf2 model:
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path) tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
...@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf( ...@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
print("-" * 100) print("-" * 100)
if config_shortcut_name in aws_config_map: if config_shortcut_name in aws_config_map:
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models) config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
else: else:
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models) config_file = config_shortcut_name
if model_shortcut_name in aws_model_maps: if model_shortcut_name in aws_model_maps:
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models) model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
else: else:
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models) model_file = model_shortcut_name
if os.path.isfile(model_shortcut_name): if os.path.isfile(model_shortcut_name):
model_shortcut_name = "converted_model" model_shortcut_name = "converted_model"
......
...@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union ...@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info from huggingface_hub import HfFolder, model_info
from .utils import ( from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
logging,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -219,18 +212,15 @@ def get_cached_module_file( ...@@ -219,18 +212,15 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file. # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
submodule = "local" submodule = "local"
else: else:
module_file_or_url = hf_bucket_url(
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep) submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_module_file = cached_path( resolved_module_file = cached_file(
module_file_or_url, pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
......
...@@ -69,20 +69,14 @@ from .utils import ( ...@@ -69,20 +69,14 @@ from .utils import (
add_end_docstrings, add_end_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
cached_path,
cached_property, cached_property,
copy_func, copy_func,
default_cache_path, default_cache_path,
define_sagemaker_information, define_sagemaker_information,
filename_to_url,
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_from_cache,
get_full_repo_name, get_full_repo_name,
get_list_of_files,
has_file, has_file,
hf_bucket_url,
http_get,
http_user_agent, http_user_agent,
is_apex_available, is_apex_available,
is_coloredlogs_available, is_coloredlogs_available,
...@@ -94,7 +88,6 @@ from .utils import ( ...@@ -94,7 +88,6 @@ from .utils import (
is_in_notebook, is_in_notebook,
is_ipex_available, is_ipex_available,
is_librosa_available, is_librosa_available,
is_local_clone,
is_offline_mode, is_offline_mode,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
...@@ -105,7 +98,6 @@ from .utils import ( ...@@ -105,7 +98,6 @@ from .utils import (
is_pyctcdecode_available, is_pyctcdecode_available,
is_pytesseract_available, is_pytesseract_available,
is_pytorch_quantization_available, is_pytorch_quantization_available,
is_remote_url,
is_rjieba_available, is_rjieba_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
...@@ -141,5 +133,4 @@ from .utils import ( ...@@ -141,5 +133,4 @@ from .utils import (
torch_only_method, torch_only_method,
torch_required, torch_required,
torch_version, torch_version,
url_to_filename,
) )
...@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import ( ...@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import (
) )
from .training_args import ParallelMode from .training_args import ParallelMode
from .utils import ( from .utils import (
CONFIG_NAME,
MODEL_CARD_NAME, MODEL_CARD_NAME,
TF2_WEIGHTS_NAME, cached_file,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available, is_datasets_available,
is_offline_mode, is_offline_mode,
is_remote_url,
is_tf_available, is_tf_available,
is_tokenizers_available, is_tokenizers_available,
is_torch_available, is_torch_available,
...@@ -153,11 +148,6 @@ class ModelCard: ...@@ -153,11 +148,6 @@ class ModelCard:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
find_from_standard_name: (*optional*) boolean, default True:
If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
with our standard modelcard filename. Can be used to directly feed a model/config url and access the
colocated modelcard.
return_unused_kwargs: (*optional*) bool: return_unused_kwargs: (*optional*) bool:
- If False, then this function returns just the final model card object. - If False, then this function returns just the final model card object.
...@@ -168,21 +158,15 @@ class ModelCard: ...@@ -168,21 +158,15 @@ class ModelCard:
Examples: Examples:
```python ```python
modelcard = ModelCard.from_pretrained( # Download model card from huggingface.co and cache.
"bert-base-uncased" modelcard = ModelCard.from_pretrained("bert-base-uncased")
) # Download model card from huggingface.co and cache. # Model card was saved using *save_pretrained('./test/saved_model/')*
modelcard = ModelCard.from_pretrained( modelcard = ModelCard.from_pretrained("./test/saved_model/")
"./test/saved_model/"
) # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json") modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False) modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
```""" ```"""
# This imports every model so let's do it dynamically here.
from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
...@@ -190,37 +174,30 @@ class ModelCard: ...@@ -190,37 +174,30 @@ class ModelCard:
if from_pipeline is not None: if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline user_agent["using_pipeline"] = from_pipeline
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: is_local = os.path.isdir(pretrained_model_name_or_path)
# For simplicity we use the same pretrained url than the configuration files if os.path.isfile(pretrained_model_name_or_path):
# but with a different suffix (modelcard.json). This suffix is replaced below. resolved_model_card_file = pretrained_model_name_or_path
model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] is_local = True
elif os.path.isdir(pretrained_model_name_or_path):
model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
else: else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None) try:
# Load from URL or cache if already cached
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: resolved_model_card_file = cached_file(
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME) pretrained_model_name_or_path,
model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME) filename=MODEL_CARD_NAME,
model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME) cache_dir=cache_dir,
proxies=proxies,
try: user_agent=user_agent,
# Load from URL or cache if already cached )
resolved_model_card_file = cached_path( if is_local:
model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent logger.info(f"loading model card file {resolved_model_card_file}")
) else:
if resolved_model_card_file == model_card_file: logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
logger.info(f"loading model card file {model_card_file}") # Load model card
else: modelcard = cls.from_json_file(resolved_model_card_file)
logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
# Load model card
modelcard = cls.from_json_file(resolved_model_card_file)
except (EnvironmentError, json.JSONDecodeError): except (EnvironmentError, json.JSONDecodeError):
# We fall back on creating an empty model card # We fall back on creating an empty model card
modelcard = cls() modelcard = cls()
# Update model card with kwargs if needed # Update model card with kwargs if needed
to_remove = [] to_remove = []
......
...@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
mirror = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
load_weight_prefix = kwargs.pop("load_weight_prefix", None) load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
...@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# message. # message.
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"mirror": mirror,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "use_auth_token": use_auth_token,
} }
...@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
mirror=mirror,
) )
config.name_or_path = pretrained_model_name_or_path config.name_or_path = pretrained_model_name_or_path
......
...@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
mirror = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True) _fast_init = kwargs.pop("_fast_init", True)
...@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# message. # message.
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"mirror": mirror,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "use_auth_token": use_auth_token,
} }
...@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
mirror=mirror,
subfolder=subfolder, subfolder=subfolder,
) )
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from .configuration_rag import RagConfig from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer from .tokenization_rag import RagTokenizer
...@@ -111,22 +111,21 @@ class LegacyIndex(Index): ...@@ -111,22 +111,21 @@ class LegacyIndex(Index):
self._index_initialized = False self._index_initialized = False
def _resolve_path(self, index_path, filename): def _resolve_path(self, index_path, filename):
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`." is_local = os.path.isdir(index_path)
archive_file = os.path.join(index_path, filename)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_archive_file = cached_path(archive_file) resolved_archive_file = cached_file(index_path, filename)
except EnvironmentError: except EnvironmentError:
msg = ( msg = (
f"Can't load '{archive_file}'. Make sure that:\n\n" f"Can't load '{filename}'. Make sure that:\n\n"
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n" f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n" f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
) )
raise EnvironmentError(msg) raise EnvironmentError(msg)
if resolved_archive_file == archive_file: if is_local:
logger.info(f"loading file {archive_file}") logger.info(f"loading file {resolved_archive_file}")
else: else:
logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}") logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
return resolved_archive_file return resolved_archive_file
def _load_passages(self): def _load_passages(self):
......
...@@ -29,7 +29,7 @@ import numpy as np ...@@ -29,7 +29,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...utils import ( from ...utils import (
cached_path, cached_file,
is_sacremoses_available, is_sacremoses_available,
is_torch_available, is_torch_available,
logging, logging,
...@@ -681,24 +681,21 @@ class TransfoXLCorpus(object): ...@@ -681,24 +681,21 @@ class TransfoXLCorpus(object):
Instantiate a pre-processed corpus. Instantiate a pre-processed corpus.
""" """
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP: is_local = os.path.isdir(pretrained_model_name_or_path)
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir) resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
except EnvironmentError: except EnvironmentError:
logger.error( logger.error(
f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list" f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'" f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
f" was a path or url but couldn't find files {corpus_file} at this path or url." f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
) )
return None return None
if resolved_corpus_file == corpus_file: if is_local:
logger.info(f"loading corpus file {corpus_file}") logger.info(f"loading corpus file {resolved_corpus_file}")
else: else:
logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}") logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")
# Instantiate tokenizer. # Instantiate tokenizer.
corpus = cls(*inputs, **kwargs) corpus = cls(*inputs, **kwargs)
......
...@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union ...@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from numpy import isin from numpy import isin
from huggingface_hub.file_download import http_get
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import get_class_from_dynamic_module from ..dynamic_module_utils import get_class_from_dynamic_module
from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..feature_extraction_utils import PreTrainedFeatureExtractor
...@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut ...@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
from ..tokenization_utils_fast import PreTrainedTokenizerFast from ..tokenization_utils_fast import PreTrainedTokenizerFast
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
from .audio_classification import AudioClassificationPipeline from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import ( from .base import (
......
...@@ -61,25 +61,16 @@ from .hub import ( ...@@ -61,25 +61,16 @@ from .hub import (
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
cached_file, cached_file,
cached_path,
default_cache_path, default_cache_path,
define_sagemaker_information, define_sagemaker_information,
filename_to_url,
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_from_cache,
get_full_repo_name, get_full_repo_name,
get_list_of_files,
has_file, has_file,
hf_bucket_url,
http_get,
http_user_agent, http_user_agent,
is_local_clone,
is_offline_mode, is_offline_mode,
is_remote_url,
move_cache, move_cache,
send_example_telemetry, send_example_telemetry,
url_to_filename,
) )
from .import_utils import ( from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_AND_AUTO_VALUES,
......
This diff is collapsed.
...@@ -26,20 +26,13 @@ import transformers ...@@ -26,20 +26,13 @@ import transformers
from transformers import * # noqa F406 from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
from transformers.utils import ( from transformers.utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
ContextManagers, ContextManagers,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels, find_labels,
get_file_from_repo, get_file_from_repo,
get_from_cache,
has_file, has_file,
hf_bucket_url,
is_flax_available, is_flax_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
...@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase): ...@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
class GetFromCacheTests(unittest.TestCase): class GetFromCacheTests(unittest.TestCase):
def test_bogus_url(self):
# This lets us simulate no connection
# as the error raised is the same
# `ConnectionError`
url = "https://bogus"
with self.assertRaisesRegex(ValueError, "Connection error"):
_ = get_from_cache(url)
def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_model_not_found_not_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
_ = get_from_cache(url)
@unittest.skip("No authentication when testing against prod")
def test_model_not_found_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
_ = get_from_cache(url, use_auth_token="hf_sometoken")
# ^ TODO - if we decide to unskip this: use a real / functional token
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_standard_object(self):
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
def test_standard_object_rev(self):
# Same object, but different revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
def test_lfs_object(self):
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
def test_has_file(self): def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)) self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
......
...@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [ ...@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [
"absl", # External module "absl", # External module
"add_end_docstrings", # Internal, should never have been in the main init. "add_end_docstrings", # Internal, should never have been in the main init.
"add_start_docstrings", # Internal, should never have been in the main init. "add_start_docstrings", # Internal, should never have been in the main init.
"cached_path", # Internal used for downloading models.
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights "convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
"logger", # Internal logger "logger", # Internal logger
"logging", # External module "logging", # External module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment