chenpangpang / transformers / Commits

Commit 377cdded (unverified), authored Aug 08, 2022 by Sylvain Gugger, committed by GitHub on Aug 08, 2022.

Clean up hub (#18497)

* Clean up utils.hub
* Remove imports
* More fixes
* Last fix

Parent: a4562552
Changes: 14 changed files with 67 additions and 708 deletions (+67, -708).
src/transformers/__init__.py (+0, -2)
src/transformers/convert_pytorch_checkpoint_to_tf2.py (+9, -10)
src/transformers/dynamic_module_utils.py (+4, -14)
src/transformers/file_utils.py (+0, -9)
src/transformers/modelcard.py (+27, -50)
src/transformers/modeling_tf_utils.py (+1, -3)
src/transformers/modeling_utils.py (+1, -3)
src/transformers/models/rag/retrieval_rag.py (+7, -8)
src/transformers/models/transfo_xl/tokenization_transfo_xl.py (+7, -10)
src/transformers/pipelines/__init__.py (+3, -1)
src/transformers/utils/__init__.py (+0, -9)
src/transformers/utils/hub.py (+8, -527)
tests/utils/test_file_utils.py (+0, -61)
utils/check_repo.py (+0, -1)
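The change that repeats across these files is the retirement of the URL-based helpers (`hf_bucket_url` plus `cached_path`) in favour of the single `cached_file` helper, which takes a repo id and a filename directly. As a rough sketch of the two call patterns (the checkpoint name and variable names below are only illustrative, not part of the commit):

```python
from transformers.utils import WEIGHTS_NAME, cached_file

# Old pattern, removed by this commit:
#   url = hf_bucket_url("bert-base-uncased", filename=WEIGHTS_NAME)
#   local_path = cached_path(url)

# New pattern: resolve repo id + filename in one call; returns a local cached path.
local_path = cached_file("bert-base-uncased", WEIGHTS_NAME)
print(local_path)
```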
src/transformers/__init__.py

@@ -441,7 +441,6 @@ _import_structure = {
         "TensorType",
         "add_end_docstrings",
         "add_start_docstrings",
-        "cached_path",
         "is_apex_available",
         "is_datasets_available",
         "is_faiss_available",

@@ -3214,7 +3213,6 @@ if TYPE_CHECKING:
         TensorType,
         add_end_docstrings,
         add_start_docstrings,
-        cached_path,
         is_apex_available,
         is_datasets_available,
         is_faiss_available,
src/transformers/convert_pytorch_checkpoint_to_tf2.py

@@ -38,7 +38,6 @@ from . import (
     T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
     TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
     WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    WEIGHTS_NAME,
     XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
     XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
     XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,

@@ -91,11 +90,10 @@ from . import (
     XLMConfig,
     XLMRobertaConfig,
     XLNetConfig,
-    cached_path,
     is_torch_available,
     load_pytorch_checkpoint_in_tf2_model,
 )
-from .utils import hf_bucket_url, logging
+from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging


 if is_torch_available():

@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(
     # Initialise TF model
     if config_file in aws_config_map:
-        config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
+        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
     config = config_class.from_json_file(config_file)
     config.output_hidden_states = True
     config.output_attentions = True

@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf(
     # Load weights from tf checkpoint
     if pytorch_checkpoint_path in aws_config_map.keys():
-        pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
-        pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
+        pytorch_checkpoint_path = cached_file(
+            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
+        )

     # Load PyTorch checkpoint in tf2 model:
     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
         print("-" * 100)
         if config_shortcut_name in aws_config_map:
-            config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
+            config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
         else:
-            config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
+            config_file = config_shortcut_name

         if model_shortcut_name in aws_model_maps:
-            model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
+            model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
         else:
-            model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
+            model_file = model_shortcut_name

         if os.path.isfile(model_shortcut_name):
             model_shortcut_name = "converted_model"
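As a minimal sketch of how the conversion script now fetches both files for a shortcut name after these hunks (the checkpoint name is only an example; `force_download` follows the script's `not use_cached_models` logic):

```python
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, cached_file

name = "bert-base-uncased"  # example shortcut name, not part of the commit
config_file = cached_file(name, CONFIG_NAME, force_download=False)
model_file = cached_file(name, WEIGHTS_NAME, force_download=False)
```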
src/transformers/dynamic_module_utils.py

@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union
 from huggingface_hub import HfFolder, model_info

-from .utils import (
-    HF_MODULES_CACHE,
-    TRANSFORMERS_DYNAMIC_MODULE_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    logging,
-)
+from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -219,18 +212,15 @@ def get_cached_module_file(
     # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isdir(pretrained_model_name_or_path):
         module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
         submodule = "local"
     else:
-        module_file_or_url = hf_bucket_url(pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None)
         submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

     try:
         # Load from URL or cache if already cached
-        resolved_module_file = cached_path(
-            module_file_or_url,
+        resolved_module_file = cached_file(
+            pretrained_model_name_or_path,
+            module_file,
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
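`get_cached_module_file` now hands the repo id and module file name straight to `cached_file` instead of building a resolve URL first. A hedged sketch of the new call shape (the repo id and file name below are placeholders for a repo that actually ships custom code):

```python
from transformers.utils import cached_file

# Hypothetical repo that ships a custom modeling file next to its weights.
resolved_module_file = cached_file(
    "some-user/model-with-custom-code",  # placeholder repo id
    "modeling_custom.py",                # placeholder module file name
    cache_dir=None,
    force_download=False,
)
```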
src/transformers/file_utils.py

@@ -69,20 +69,14 @@ from .utils import (
     add_end_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    cached_path,
     cached_property,
     copy_func,
     default_cache_path,
     define_sagemaker_information,
-    filename_to_url,
     get_cached_models,
     get_file_from_repo,
-    get_from_cache,
     get_full_repo_name,
-    get_list_of_files,
     has_file,
-    hf_bucket_url,
-    http_get,
     http_user_agent,
     is_apex_available,
     is_coloredlogs_available,

@@ -94,7 +88,6 @@ from .utils import (
     is_in_notebook,
     is_ipex_available,
     is_librosa_available,
-    is_local_clone,
     is_offline_mode,
     is_onnx_available,
     is_pandas_available,

@@ -105,7 +98,6 @@ from .utils import (
     is_pyctcdecode_available,
     is_pytesseract_available,
     is_pytorch_quantization_available,
-    is_remote_url,
     is_rjieba_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,

@@ -141,5 +133,4 @@ from .utils import (
     torch_only_method,
     torch_required,
     torch_version,
-    url_to_filename,
 )
src/transformers/modelcard.py

@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import (
 )
 from .training_args import ParallelMode
 from .utils import (
-    CONFIG_NAME,
     MODEL_CARD_NAME,
-    TF2_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    cached_path,
-    hf_bucket_url,
+    cached_file,
     is_datasets_available,
     is_offline_mode,
-    is_remote_url,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,

@@ -153,11 +148,6 @@ class ModelCard:
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
-            find_from_standard_name: (*optional*) boolean, default True:
-                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
-                with our standard modelcard filename. Can be used to directly feed a model/config url and access the
-                colocated modelcard.
             return_unused_kwargs: (*optional*) bool:
                 - If False, then this function returns just the final model card object.

@@ -168,21 +158,15 @@ class ModelCard:
         Examples:

         ```python
-        modelcard = ModelCard.from_pretrained(
-            "bert-base-uncased"
-        )  # Download model card from huggingface.co and cache.
-        modelcard = ModelCard.from_pretrained(
-            "./test/saved_model/"
-        )  # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
+        # Download model card from huggingface.co and cache.
+        modelcard = ModelCard.from_pretrained("bert-base-uncased")
+        # Model card was saved using *save_pretrained('./test/saved_model/')*
+        modelcard = ModelCard.from_pretrained("./test/saved_model/")
         modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
         modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
         ```"""
-        # This imports every model so let's do it dynamically here.
-        from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
-
         cache_dir = kwargs.pop("cache_dir", None)
         proxies = kwargs.pop("proxies", None)
-        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
         from_pipeline = kwargs.pop("_from_pipeline", None)

@@ -190,31 +174,24 @@ class ModelCard:
         if from_pipeline is not None:
             user_agent["using_pipeline"] = from_pipeline

-        if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
-            # For simplicity we use the same pretrained url than the configuration files
-            # but with a different suffix (modelcard.json). This suffix is replaced below.
-            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        elif os.path.isdir(pretrained_model_name_or_path):
-            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
-        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            model_card_file = pretrained_model_name_or_path
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_model_card_file = pretrained_model_name_or_path
+            is_local = True
         else:
-            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
-
-        if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
-            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
-            model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
-            model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
-
             try:
                 # Load from URL or cache if already cached
-                resolved_model_card_file = cached_path(
-                    model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent
+                resolved_model_card_file = cached_file(
+                    pretrained_model_name_or_path,
+                    filename=MODEL_CARD_NAME,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    user_agent=user_agent,
                 )
-                if resolved_model_card_file == model_card_file:
-                    logger.info(f"loading model card file {model_card_file}")
+                if is_local:
+                    logger.info(f"loading model card file {resolved_model_card_file}")
                 else:
-                    logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
+                    logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
                 # Load model card
                 modelcard = cls.from_json_file(resolved_model_card_file)
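The model-card resolution therefore collapses to: use the path directly if it is a file, otherwise ask `cached_file` for the fixed `MODEL_CARD_NAME`. A condensed sketch of that flow (error handling omitted; not every repo actually hosts a model-card JSON, in which case `cached_file` raises and the real method handles the failure):

```python
import os

from transformers.modelcard import ModelCard
from transformers.utils import MODEL_CARD_NAME, cached_file

name_or_path = "bert-base-uncased"  # example identifier
if os.path.isfile(name_or_path):
    resolved = name_or_path
else:
    resolved = cached_file(name_or_path, filename=MODEL_CARD_NAME)
modelcard = ModelCard.from_json_file(resolved)
```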
src/transformers/modeling_tf_utils.py

@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        mirror = kwargs.pop("mirror", None)
+        _ = kwargs.pop("mirror", None)
         load_weight_prefix = kwargs.pop("load_weight_prefix", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)

@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 # message.
                 has_file_kwargs = {
                     "revision": revision,
-                    "mirror": mirror,
                     "proxies": proxies,
                     "use_auth_token": use_auth_token,
                 }

@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
-                mirror=mirror,
             )

         config.name_or_path = pretrained_model_name_or_path
src/transformers/modeling_utils.py

@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        mirror = kwargs.pop("mirror", None)
+        _ = kwargs.pop("mirror", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         _fast_init = kwargs.pop("_fast_init", True)

@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # message.
                 has_file_kwargs = {
                     "revision": revision,
-                    "mirror": mirror,
                     "proxies": proxies,
                     "use_auth_token": use_auth_token,
                 }

@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
-                mirror=mirror,
                 subfolder=subfolder,
             )
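In both `PreTrainedModel.from_pretrained` and its TF counterpart, the `mirror` keyword is now popped into `_` and never forwarded, so passing it has no effect on where files are fetched from. A small illustration (the model name is an example; whether other layers ever warn about the unused argument is not shown in this diff):

```python
from transformers import BertModel

# After this change `mirror` is still accepted but dropped
# (`_ = kwargs.pop("mirror", None)`), so both calls resolve files the same way.
model_a = BertModel.from_pretrained("bert-base-uncased")
model_b = BertModel.from_pretrained("bert-base-uncased", mirror="tuna")
```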
src/transformers/models/rag/retrieval_rag.py

@@ -23,7 +23,7 @@ import numpy as np

 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import BatchEncoding
-from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
+from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
 from .configuration_rag import RagConfig
 from .tokenization_rag import RagTokenizer

@@ -111,22 +111,21 @@ class LegacyIndex(Index):
         self._index_initialized = False

     def _resolve_path(self, index_path, filename):
-        assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
-        archive_file = os.path.join(index_path, filename)
+        is_local = os.path.isdir(index_path)
         try:
             # Load from URL or cache if already cached
-            resolved_archive_file = cached_path(archive_file)
+            resolved_archive_file = cached_file(index_path, filename)
         except EnvironmentError:
             msg = (
-                f"Can't load '{archive_file}'. Make sure that:\n\n"
+                f"Can't load '{filename}'. Make sure that:\n\n"
                 f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
                 f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
             )
             raise EnvironmentError(msg)
-        if resolved_archive_file == archive_file:
-            logger.info(f"loading file {archive_file}")
+        if is_local:
+            logger.info(f"loading file {resolved_archive_file}")
         else:
-            logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
+            logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
         return resolved_archive_file

     def _load_passages(self):
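`LegacyIndex._resolve_path` now treats `index_path` as either a local directory or a Hub repo id and lets `cached_file` do the lookup, with the log message being the only branch. A standalone sketch of that logic (path and file name are placeholders):

```python
import os

from transformers.utils import cached_file


def resolve_index_file(index_path, filename):
    # Sketch of the new LegacyIndex._resolve_path: local directories and Hub
    # repo ids go through the same cached_file call.
    is_local = os.path.isdir(index_path)
    resolved = cached_file(index_path, filename)
    if is_local:
        print(f"loading file {resolved}")
    else:
        print(f"loading file {filename} from cache at {resolved}")
    return resolved
```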
src/transformers/models/transfo_xl/tokenization_transfo_xl.py

@@ -29,7 +29,7 @@ import numpy as np

 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import (
-    cached_path,
+    cached_file,
     is_sacremoses_available,
     is_torch_available,
     logging,

@@ -681,24 +681,21 @@ class TransfoXLCorpus(object):
         Instantiate a pre-processed corpus.
         """
         vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
-            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
         # redirect to the cache, if necessary
         try:
-            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
+            resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
         except EnvironmentError:
             logger.error(
                 f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
                 f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
-                f" was a path or url but couldn't find files {corpus_file} at this path or url."
+                f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
             )
             return None
-        if resolved_corpus_file == corpus_file:
-            logger.info(f"loading corpus file {corpus_file}")
+        if is_local:
+            logger.info(f"loading corpus file {resolved_corpus_file}")
         else:
-            logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")
+            logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")

         # Instantiate tokenizer.
         corpus = cls(*inputs, **kwargs)
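`TransfoXLCorpus.from_pretrained` follows the same pattern: the pre-processed corpus file is resolved with `cached_file` under the fixed `CORPUS_NAME`, and `None` is returned if it cannot be found. A short usage sketch (assuming the checkpoint repo actually hosts the corpus file):

```python
from transformers.models.transfo_xl.tokenization_transfo_xl import TransfoXLCorpus

# Resolves CORPUS_NAME from the repo (or a local directory) via cached_file;
# returns None if the corpus file cannot be resolved.
corpus = TransfoXLCorpus.from_pretrained("transfo-xl-wt103")
```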
src/transformers/pipelines/__init__.py

@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from numpy import isin

+from huggingface_hub.file_download import http_get
+
 from ..configuration_utils import PretrainedConfig
 from ..dynamic_module_utils import get_class_from_dynamic_module
 from ..feature_extraction_utils import PreTrainedFeatureExtractor

@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
 from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
 from ..tokenization_utils import PreTrainedTokenizer
 from ..tokenization_utils_fast import PreTrainedTokenizerFast
-from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
 from .audio_classification import AudioClassificationPipeline
 from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
 from .base import (
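`pipelines/__init__.py` still needs a raw HTTP download helper, but it now takes `http_get` from `huggingface_hub.file_download` instead of the deleted `transformers.utils` copy. A hedged sketch of that helper's call shape as assumed here (url plus a writable binary file object; the exact signature lives in `huggingface_hub`):

```python
import tempfile

from huggingface_hub.file_download import http_get

# Stream a remote file into a local temporary file (URL is just an example).
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    http_get("https://huggingface.co/bert-base-uncased/resolve/main/config.json", tmp)
    local_path = tmp.name
```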
src/transformers/utils/__init__.py

@@ -61,25 +61,16 @@ from .hub import (
     RepositoryNotFoundError,
     RevisionNotFoundError,
     cached_file,
-    cached_path,
     default_cache_path,
     define_sagemaker_information,
-    filename_to_url,
     get_cached_models,
     get_file_from_repo,
-    get_from_cache,
     get_full_repo_name,
-    get_list_of_files,
     has_file,
-    hf_bucket_url,
-    http_get,
     http_user_agent,
-    is_local_clone,
     is_offline_mode,
-    is_remote_url,
     move_cache,
     send_example_telemetry,
-    url_to_filename,
 )
 from .import_utils import (
     ENV_VARS_TRUE_AND_AUTO_VALUES,
src/transformers/utils/hub.py

@@ -14,44 +14,32 @@
 """
 Hub utilities: utilities related to download and cache models
 """
-import copy
-import fnmatch
-import io
 import json
 import os
 import re
 import shutil
 import subprocess
 import sys
-import tarfile
 import tempfile
 import traceback
 import warnings
-from contextlib import contextmanager
-from functools import partial
-from hashlib import sha256
 from pathlib import Path
-from typing import BinaryIO, Dict, List, Optional, Tuple, Union
-from urllib.parse import urlparse
+from typing import Dict, List, Optional, Tuple, Union
 from uuid import uuid4
-from zipfile import ZipFile, is_zipfile

 import huggingface_hub
 import requests
 from filelock import FileLock
 from huggingface_hub import (
     CommitOperationAdd,
     HfFolder,
     create_commit,
     create_repo,
     hf_hub_download,
     list_repo_files,
+    hf_hub_url,
     whoami,
 )
 from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests.exceptions import HTTPError
-from requests.models import Response
 from transformers.utils.logging import tqdm

 from . import __version__, logging
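With the bespoke download stack gone, `utils/hub.py` leans on `huggingface_hub` primitives directly: `hf_hub_download` / `hf_hub_url` on the download side and `create_repo` / `create_commit` for pushes. A brief sketch of the download-side primitive (repo and filename are examples; this is the kind of call the remaining helpers appear to build on, not a quote of their implementation):

```python
from huggingface_hub import hf_hub_download

# Download (or reuse from the local cache) one file of a Hub repo and
# return its local path.
config_path = hf_hub_download("bert-base-uncased", filename="config.json")
```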
@@ -128,93 +116,6 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{
 HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"


-def is_remote_url(url_or_filename):
-    parsed = urlparse(url_or_filename)
-    return parsed.scheme in ("http", "https")
-
-
-def hf_bucket_url(model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None) -> str:
-    """
-    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
-    to Cloudfront (a Content Delivery Network, or CDN) for large files.
-
-    Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
-    bandwidth costs).
-
-    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
-    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
-    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
-    can't ever be stale.
-
-    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
-    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
-    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
-    """
-    if subfolder is not None:
-        filename = f"{subfolder}/{filename}"
-
-    if mirror:
-        if mirror in ["tuna", "bfsu"]:
-            raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")
-        legacy_format = "/" not in model_id
-        if legacy_format:
-            return f"{mirror}/{model_id}-{filename}"
-        else:
-            return f"{mirror}/{model_id}/{filename}"
-
-    if revision is None:
-        revision = "main"
-    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
-
-
-def url_to_filename(url: str, etag: Optional[str] = None) -> str:
-    """
-    Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
-    delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
-    identify it as a HDF5 file (see
-    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
-    """
-    url_bytes = url.encode("utf-8")
-    filename = sha256(url_bytes).hexdigest()
-    if etag:
-        etag_bytes = etag.encode("utf-8")
-        filename += "." + sha256(etag_bytes).hexdigest()
-
-    if url.endswith(".h5"):
-        filename += ".h5"
-
-    return filename
-
-
-def filename_to_url(filename, cache_dir=None):
-    """
-    Return the url and etag (which may be `None`) stored for *filename*. Raise `EnvironmentError` if *filename* or its
-    stored metadata do not exist.
-    """
-    if cache_dir is None:
-        cache_dir = TRANSFORMERS_CACHE
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-
-    cache_path = os.path.join(cache_dir, filename)
-    if not os.path.exists(cache_path):
-        raise EnvironmentError(f"file {cache_path} not found")
-
-    meta_path = cache_path + ".json"
-    if not os.path.exists(meta_path):
-        raise EnvironmentError(f"file {meta_path} not found")
-
-    with open(meta_path, encoding="utf-8") as meta_file:
-        metadata = json.load(meta_file)
-    url = metadata["url"]
-    etag = metadata["etag"]
-
-    return url, etag
-
-
 def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
     """
     Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
@@ -248,108 +149,6 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
     return cached_models


-def cached_path(
-    url_or_filename,
-    cache_dir=None,
-    force_download=False,
-    proxies=None,
-    resume_download=False,
-    user_agent: Union[Dict, str, None] = None,
-    extract_compressed_file=False,
-    force_extract=False,
-    use_auth_token: Union[bool, str, None] = None,
-    local_files_only=False,
-) -> Optional[str]:
-    """
-    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
-    and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
-    then return the path
-
-    Args:
-        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
-        force_download: if True, re-download the file even if it's already cached in the cache dir.
-        resume_download: if True, resume the download if incompletely received file is found.
-        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
-        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
-            will get token from ~/.huggingface.
-        extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
-            file in a folder along the archive.
-        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
-            re-extract the archive and override the folder where it was extracted.
-
-    Return:
-        Local path (string) of file or if networking is off, last version of file cached on disk.
-
-    Raises:
-        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
-    """
-    if cache_dir is None:
-        cache_dir = TRANSFORMERS_CACHE
-    if isinstance(url_or_filename, Path):
-        url_or_filename = str(url_or_filename)
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-
-    if is_offline_mode() and not local_files_only:
-        logger.info("Offline mode: forcing local_files_only=True")
-        local_files_only = True
-
-    if is_remote_url(url_or_filename):
-        # URL, so get it from the cache (downloading if necessary)
-        output_path = get_from_cache(
-            url_or_filename,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
-            user_agent=user_agent,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-        )
-    elif os.path.exists(url_or_filename):
-        # File, and it exists.
-        output_path = url_or_filename
-    elif urlparse(url_or_filename).scheme == "":
-        # File, but it doesn't exist.
-        raise EnvironmentError(f"file {url_or_filename} not found")
-    else:
-        # Something unknown
-        raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
-
-    if extract_compressed_file:
-        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
-            return output_path
-
-        # Path where we extract compressed archives
-        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
-        output_dir, output_file = os.path.split(output_path)
-        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
-        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
-
-        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
-            return output_path_extracted
-
-        # Prevent parallel extractions
-        lock_path = output_path + ".lock"
-        with FileLock(lock_path):
-            shutil.rmtree(output_path_extracted, ignore_errors=True)
-            os.makedirs(output_path_extracted)
-            if is_zipfile(output_path):
-                with ZipFile(output_path, "r") as zip_file:
-                    zip_file.extractall(output_path_extracted)
-                    zip_file.close()
-            elif tarfile.is_tarfile(output_path):
-                tar_file = tarfile.open(output_path)
-                tar_file.extractall(output_path_extracted)
-                tar_file.close()
-            else:
-                raise EnvironmentError(f"Archive format of {output_path} could not be identified")
-
-        return output_path_extracted
-
-    return output_path
-
-
 def define_sagemaker_information():
     try:
         instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
@@ -399,234 +198,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua


-def _raise_for_status(response: Response):
-    """
-    Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
-    """
-    if "X-Error-Code" in response.headers:
-        error_code = response.headers["X-Error-Code"]
-        if error_code == "RepoNotFound":
-            raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
-        elif error_code == "EntryNotFound":
-            raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
-        elif error_code == "RevisionNotFound":
-            raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
-
-    if response.status_code == 401:
-        # The repo was not found and the user is not Authenticated
-        raise RepositoryNotFoundError(
-            f"401 Client Error: Repository not found for url: {response.url}. "
-            "If the repo is private, make sure you are authenticated."
-        )
-
-    response.raise_for_status()
-
-
-def http_get(
-    url: str,
-    temp_file: BinaryIO,
-    proxies=None,
-    resume_size=0,
-    headers: Optional[Dict[str, str]] = None,
-    file_name: Optional[str] = None,
-):
-    """
-    Download remote file. Do not gobble up errors.
-    """
-    headers = copy.deepcopy(headers)
-    if resume_size > 0:
-        headers["Range"] = f"bytes={resume_size}-"
-    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
-    _raise_for_status(r)
-    content_length = r.headers.get("Content-Length")
-    total = resume_size + int(content_length) if content_length is not None else None
-    # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
-    # and can be set using `utils.logging.enable/disable_progress_bar()`
-    progress = tqdm(
-        unit="B",
-        unit_scale=True,
-        unit_divisor=1024,
-        total=total,
-        initial=resume_size,
-        desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
-    )
-    for chunk in r.iter_content(chunk_size=1024):
-        if chunk:  # filter out keep-alive new chunks
-            progress.update(len(chunk))
-            temp_file.write(chunk)
-    progress.close()
-
-
-def get_from_cache(
-    url: str,
-    cache_dir=None,
-    force_download=False,
-    proxies=None,
-    etag_timeout=10,
-    resume_download=False,
-    user_agent: Union[Dict, str, None] = None,
-    use_auth_token: Union[bool, str, None] = None,
-    local_files_only=False,
-) -> Optional[str]:
-    """
-    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
-    path to the cached file.
-
-    Return:
-        Local path (string) of file or if networking is off, last version of file cached on disk.
-
-    Raises:
-        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
-    """
-    if cache_dir is None:
-        cache_dir = TRANSFORMERS_CACHE
-    if isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-
-    os.makedirs(cache_dir, exist_ok=True)
-
-    headers = {"user-agent": http_user_agent(user_agent)}
-    if isinstance(use_auth_token, str):
-        headers["authorization"] = f"Bearer {use_auth_token}"
-    elif use_auth_token:
-        token = HfFolder.get_token()
-        if token is None:
-            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
-        headers["authorization"] = f"Bearer {token}"
-
-    url_to_download = url
-    etag = None
-    if not local_files_only:
-        try:
-            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
-            _raise_for_status(r)
-            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
-            # We favor a custom header indicating the etag of the linked resource, and
-            # we fallback to the regular etag header.
-            # If we don't have any of those, raise an error.
-            if etag is None:
-                raise OSError(
-                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
-                )
-            # In case of a redirect,
-            # save an extra redirect on the request.get call,
-            # and ensure we download the exact atomic version even if it changed
-            # between the HEAD and the GET (unlikely, but hey).
-            if 300 <= r.status_code <= 399:
-                url_to_download = r.headers["Location"]
-        except (
-            requests.exceptions.SSLError,
-            requests.exceptions.ProxyError,
-            RepositoryNotFoundError,
-            EntryNotFoundError,
-            RevisionNotFoundError,
-        ):
-            # Actually raise for those subclasses of ConnectionError
-            # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
-            raise
-        except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
-            # Otherwise, our Internet connection is down.
-            # etag is None
-            pass
-
-    filename = url_to_filename(url, etag)
-
-    # get cache path to put the file
-    cache_path = os.path.join(cache_dir, filename)
-
-    # etag is None == we don't have a connection or we passed local_files_only.
-    # try to get the last downloaded one
-    if etag is None:
-        if os.path.exists(cache_path):
-            return cache_path
-        else:
-            matching_files = [
-                file
-                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
-                if not file.endswith(".json") and not file.endswith(".lock")
-            ]
-            if len(matching_files) > 0:
-                return os.path.join(cache_dir, matching_files[-1])
-            else:
-                # If files cannot be found and local_files_only=True,
-                # the models might've been found if local_files_only=False
-                # Notify the user about that
-                if local_files_only:
-                    fname = url.split("/")[-1]
-                    raise EntryNotFoundError(
-                        f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been"
-                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
-                        " to False."
-                    )
-                else:
-                    raise ValueError(
-                        "Connection error, and we cannot find the requested files in the cached path."
-                        " Please try again or make sure your Internet connection is on."
-                    )
-
-    # From now on, etag is not None.
-    if os.path.exists(cache_path) and not force_download:
-        return cache_path
-
-    # Prevent parallel downloads of the same file with a lock.
-    lock_path = cache_path + ".lock"
-    with FileLock(lock_path):
-
-        # If the download just completed while the lock was activated.
-        if os.path.exists(cache_path) and not force_download:
-            # Even if returning early like here, the lock will be released.
-            return cache_path
-
-        if resume_download:
-            incomplete_path = cache_path + ".incomplete"
-
-            @contextmanager
-            def _resumable_file_manager() -> "io.BufferedWriter":
-                with open(incomplete_path, "ab") as f:
-                    yield f
-
-            temp_file_manager = _resumable_file_manager
-            if os.path.exists(incomplete_path):
-                resume_size = os.stat(incomplete_path).st_size
-            else:
-                resume_size = 0
-        else:
-            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
-            resume_size = 0
-
-        # Download to temporary file, then copy to cache dir once finished.
-        # Otherwise you get corrupt cache entries if the download gets interrupted.
-        with temp_file_manager() as temp_file:
-            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
-
-            # The url_to_download might be messy, so we extract the file name from the original url.
-            file_name = url.split("/")[-1]
-            http_get(
-                url_to_download,
-                temp_file,
-                proxies=proxies,
-                resume_size=resume_size,
-                headers=headers,
-                file_name=file_name,
-            )
-
-        logger.info(f"storing {url} in cache at {cache_path}")
-        os.replace(temp_file.name, cache_path)
-
-        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
-        umask = os.umask(0o666)
-        os.umask(umask)
-        os.chmod(cache_path, 0o666 & ~umask)
-
-        logger.info(f"creating metadata file for {cache_path}")
-        meta = {"url": url, "etag": etag}
-        meta_path = cache_path + ".json"
-        with open(meta_path, "w") as meta_file:
-            json.dump(meta, meta_file)
-
-    return cache_path
-
-
 def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
     """
     Explores the cache to return the latest cached file for a given revision.
@@ -919,7 +490,6 @@ def has_file(
     path_or_repo: Union[str, os.PathLike],
     filename: str,
     revision: Optional[str] = None,
-    mirror: Optional[str] = None,
     proxies: Optional[Dict[str, str]] = None,
     use_auth_token: Optional[Union[bool, str]] = None,
 ):

@@ -936,7 +506,7 @@ def has_file(
     if os.path.isdir(path_or_repo):
         return os.path.isfile(os.path.join(path_or_repo, filename))

-    url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
+    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
     headers = {"user-agent": http_user_agent()}
     if isinstance(use_auth_token, str):
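`has_file` keeps its behaviour but now builds the check URL with `hf_hub_url` from `huggingface_hub` instead of the deleted `hf_bucket_url`, and its `mirror` parameter is gone. Usage from the caller's side is unchanged; for example (mirroring the assertions of the test removed below):

```python
from transformers.utils import has_file

# True: this test repo ships a PyTorch checkpoint; False: it has no TF2 weights.
print(has_file("hf-internal-testing/tiny-bert-pt-only", "pytorch_model.bin"))
print(has_file("hf-internal-testing/tiny-bert-pt-only", "tf_model.h5"))
```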
@@ -965,89 +535,6 @@ def has_file(
     return False


-def get_list_of_files(
-    path_or_repo: Union[str, os.PathLike],
-    revision: Optional[str] = None,
-    use_auth_token: Optional[Union[bool, str]] = None,
-    local_files_only: bool = False,
-) -> List[str]:
-    """
-    Gets the list of files inside `path_or_repo`.
-
-    Args:
-        path_or_repo (`str` or `os.PathLike`):
-            Can be either the id of a repo on huggingface.co or a path to a *directory*.
-        revision (`str`, *optional*, defaults to `"main"`):
-            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-            identifier allowed by git.
-        use_auth_token (`str` or *bool*, *optional*):
-            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-            when running `huggingface-cli login` (stored in `~/.huggingface`).
-        local_files_only (`bool`, *optional*, defaults to `False`):
-            Whether or not to only rely on local files and not to attempt to download any files.
-
-    <Tip warning={true}>
-
-    This API is not optimized, so calling it a lot may result in connection errors.
-
-    </Tip>
-
-    Returns:
-        `List[str]`: The list of files available in `path_or_repo`.
-    """
-    path_or_repo = str(path_or_repo)
-    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
-    if os.path.isdir(path_or_repo):
-        list_of_files = []
-        for path, dir_names, file_names in os.walk(path_or_repo):
-            list_of_files.extend([os.path.join(path, f) for f in file_names])
-        return list_of_files
-
-    # Can't grab the files if we are on offline mode.
-    if is_offline_mode() or local_files_only:
-        return []
-
-    # Otherwise we grab the token and use the list_repo_files method.
-    if isinstance(use_auth_token, str):
-        token = use_auth_token
-    elif use_auth_token is True:
-        token = HfFolder.get_token()
-    else:
-        token = None
-    try:
-        return list_repo_files(path_or_repo, revision=revision, token=token)
-    except HTTPError as e:
-        raise ValueError(
-            f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
-        ) from e
-
-
-def is_local_clone(repo_path, repo_url):
-    """
-    Checks if the folder in `repo_path` is a local clone of `repo_url`.
-    """
-    # First double-check that `repo_path` is a git repo
-    if not os.path.exists(os.path.join(repo_path, ".git")):
-        return False
-    test_git = subprocess.run("git branch".split(), cwd=repo_path)
-    if test_git.returncode != 0:
-        return False
-
-    # Then look at its remotes
-    remotes = subprocess.run(
-        "git remote -v".split(),
-        stderr=subprocess.PIPE,
-        stdout=subprocess.PIPE,
-        check=True,
-        encoding="utf-8",
-        cwd=repo_path,
-    ).stdout
-    return repo_url in remotes.split()
-
-
 class PushToHubMixin:
     """
     A Mixin containing the functionality to push a model or tokenizer to the hub.
@@ -1310,7 +797,6 @@ def get_checkpoint_shard_files(
     use_auth_token=None,
     user_agent=None,
     revision=None,
-    mirror=None,
     subfolder="",
 ):
     """

@@ -1343,18 +829,11 @@ def get_checkpoint_shard_files(
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     cached_filenames = []
     for shard_filename in shard_filenames:
-        shard_url = hf_bucket_url(
-            pretrained_model_name_or_path,
-            filename=shard_filename,
-            revision=revision,
-            mirror=mirror,
-            subfolder=subfolder if len(subfolder) > 0 else None,
-        )
-
         try:
             # Load from URL
-            cached_filename = cached_path(
-                shard_url,
+            cached_filename = cached_file(
+                pretrained_model_name_or_path,
+                shard_filename,
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,

@@ -1362,6 +841,8 @@ def get_checkpoint_shard_files(
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                revision=revision,
+                subfolder=subfolder,
             )

         # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
         # we don't have to catch them here.
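Sharded checkpoints follow the same conversion: each shard listed in the index is fetched through `cached_file` with the repo id and the shard's filename, with `revision` and `subfolder` forwarded directly. A reduced sketch of that loop (repo id and shard names are placeholders; in practice they come from the checkpoint's index JSON):

```python
from transformers.utils import cached_file

repo_id = "some-org/some-large-model"  # placeholder repo id
shard_filenames = [                    # normally read from the *.index.json file
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
]
cached_filenames = [cached_file(repo_id, shard) for shard in shard_filenames]
```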
tests/utils/test_file_utils.py

@@ -26,20 +26,13 @@ import transformers
 from transformers import *  # noqa F406
 from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
 from transformers.utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     WEIGHTS_NAME,
     ContextManagers,
-    EntryNotFoundError,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    filename_to_url,
     find_labels,
     get_file_from_repo,
-    get_from_cache,
-    has_file,
-    hf_bucket_url,
     is_flax_available,
     is_tf_available,
     is_torch_available,

@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
-class GetFromCacheTests(unittest.TestCase):
-    def test_bogus_url(self):
-        # This lets us simulate no connection
-        # as the error raised is the same
-        # `ConnectionError`
-        url = "https://bogus"
-        with self.assertRaisesRegex(ValueError, "Connection error"):
-            _ = get_from_cache(url)
-
-    def test_file_not_found(self):
-        # Valid revision (None) but missing file.
-        url = hf_bucket_url(MODEL_ID, filename="missing.bin")
-        with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
-            _ = get_from_cache(url)
-
-    def test_model_not_found_not_authenticated(self):
-        # Invalid model id.
-        url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
-        with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
-            _ = get_from_cache(url)
-
-    @unittest.skip("No authentication when testing against prod")
-    def test_model_not_found_authenticated(self):
-        # Invalid model id.
-        url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
-        with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
-            _ = get_from_cache(url, use_auth_token="hf_sometoken")
-        # ^ TODO - if we decide to unskip this: use a real / functional token
-
-    def test_revision_not_found(self):
-        # Valid file but missing revision
-        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
-        with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
-            _ = get_from_cache(url)
-
-    def test_standard_object(self):
-        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
-        filepath = get_from_cache(url, force_download=True)
-        metadata = filename_to_url(filepath)
-        self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
-
-    def test_standard_object_rev(self):
-        # Same object, but different revision
-        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
-        filepath = get_from_cache(url, force_download=True)
-        metadata = filename_to_url(filepath)
-        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
-        # Caution: check that the etag is *not* equal to the one from `test_standard_object`
-
-    def test_lfs_object(self):
-        url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
-        filepath = get_from_cache(url, force_download=True)
-        metadata = filename_to_url(filepath)
-        self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
-
-    def test_has_file(self):
-        self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
-        self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
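The deleted `GetFromCacheTests` exercised the URL-based stack (`hf_bucket_url` + `get_from_cache` + `filename_to_url`) that no longer exists. An equivalent smoke test against the surviving API could look roughly like this (a sketch only, not part of the commit; the tiny test repo name is an assumption):

```python
import os
import unittest

from transformers.utils import CONFIG_NAME, cached_file


class CachedFileSmokeTest(unittest.TestCase):
    def test_resolves_config(self):
        # Resolve a small config file and check that a real local path comes back.
        filepath = cached_file("hf-internal-testing/tiny-random-bert", CONFIG_NAME)
        self.assertTrue(os.path.isfile(filepath))
```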
utils/check_repo.py

@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [
     "absl",  # External module
     "add_end_docstrings",  # Internal, should never have been in the main init.
     "add_start_docstrings",  # Internal, should never have been in the main init.
-    "cached_path",  # Internal used for downloading models.
     "convert_tf_weight_name_to_pt_weight_name",  # Internal used to convert model weights
     "logger",  # Internal logger
     "logging",  # External module