"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "280c757f6cf53a9c2857d8273b9fdfdf3372971d"
Unverified Commit f89f16a5 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Re-add support for single url files in objects download (#19014)

parent ad5045e3
......@@ -32,7 +32,9 @@ from .utils import (
PushToHubMixin,
cached_file,
copy_func,
download_url,
extract_commit_hash,
is_remote_url,
is_torch_available,
logging,
)
......@@ -592,9 +594,12 @@ class PretrainedConfig(PushToHubMixin):
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Soecial case when pretrained_model_name_or_path is a local file
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
configuration_file = pretrained_model_name_or_path
resolved_config_file = download_url(pretrained_model_name_or_path)
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
......
......@@ -31,8 +31,10 @@ from .utils import (
TensorType,
cached_file,
copy_func,
download_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
logging,
......@@ -386,6 +388,9 @@ class FeatureExtractionMixin(PushToHubMixin):
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
else:
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
......
......@@ -47,8 +47,10 @@ from .utils import (
add_start_docstrings_to_model_forward,
cached_file,
copy_func,
download_url,
has_file,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
......@@ -677,6 +679,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
try:
......
......@@ -54,9 +54,11 @@ from .utils import (
ModelOutput,
PushToHubMixin,
cached_file,
download_url,
find_labels,
has_file,
is_offline_mode,
is_remote_url,
logging,
requires_backends,
working_or_temp_dir,
......@@ -2345,6 +2347,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
......
......@@ -59,10 +59,12 @@ from .utils import (
PushToHubMixin,
cached_file,
copy_func,
download_url,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
......@@ -1998,6 +2000,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
if from_tf:
......
......@@ -42,9 +42,11 @@ from .utils import (
add_end_docstrings,
cached_file,
copy_func,
download_url,
extract_commit_hash,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
......@@ -1680,6 +1682,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
FutureWarning,
)
file_id = list(cls.vocab_files_names.keys())[0]
vocab_files[file_id] = pretrained_model_name_or_path
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
......@@ -1723,6 +1726,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
elif is_remote_url(file_path):
resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
else:
resolved_vocab_files[file_id] = cached_file(
pretrained_model_name_or_path,
......
......@@ -63,6 +63,7 @@ from .hub import (
cached_file,
default_cache_path,
define_sagemaker_information,
download_url,
extract_commit_hash,
get_cached_models,
get_file_from_repo,
......@@ -70,6 +71,7 @@ from .hub import (
has_file,
http_user_agent,
is_offline_mode,
is_remote_url,
move_cache,
send_example_telemetry,
)
......
......@@ -19,10 +19,12 @@ import os
import re
import shutil
import sys
import tempfile
import traceback
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4
import huggingface_hub
......@@ -37,7 +39,7 @@ from huggingface_hub import (
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
EntryNotFoundError,
LocalEntryNotFoundError,
......@@ -124,6 +126,11 @@ HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/tele
_CACHED_NO_EXIST = object()
def is_remote_url(url_or_filename):
    """
    Tells whether `url_or_filename` points to a remote file served over HTTP(S).

    Args:
        url_or_filename (`str`): The string to inspect (a url, a local path or a model identifier).

    Returns:
        `bool`: `True` when the parsed scheme is `http` or `https`, `False` otherwise.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
......@@ -541,6 +548,32 @@ def get_file_from_repo(
)
def download_url(url, proxies=None):
    """
    Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
    for deprecated behavior allowing to download config/models with a single url instead of using the Hub.

    Args:
        url (`str`): The url of the file to download.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.

    Returns:
        `str`: The location of the temporary file where the url was downloaded. The file is not deleted by this
        function when the download succeeds; the caller owns it.
    """
    warnings.warn(
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file)."
    )
    # `tempfile.mktemp` only returns a name, so another process can create that path before we open it. `mkstemp`
    # creates the file atomically and hands back an already-open descriptor.
    tmp_fd, tmp_file = tempfile.mkstemp()
    try:
        with os.fdopen(tmp_fd, "wb") as f:
            http_get(url, f, proxies=proxies)
    except Exception:
        # A failed download would otherwise leave a truncated file behind that nothing ever deletes.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    return tmp_file
def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment