Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f7ceda34
Unverified
Commit
f7ceda34
authored
Sep 12, 2022
by
Sylvain Gugger
Committed by
GitHub
Sep 12, 2022
Browse files
Align try_to_load_from_cache with huggingface_hub (#18966)
* Align try_to_load_from_cache with huggingface_hub * Fix tests
parent
cf450b77
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
26 deletions
+37
-26
src/transformers/utils/hub.py
src/transformers/utils/hub.py
+37
-26
No files found.
src/transformers/utils/hub.py
View file @
f7ceda34
...
@@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
...
@@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
return
commit_hash
if
REGEX_COMMIT_HASH
.
match
(
commit_hash
)
else
None
return
commit_hash
if
REGEX_COMMIT_HASH
.
match
(
commit_hash
)
else
None
def
try_to_load_from_cache
(
cache_dir
,
repo_id
,
filename
,
revision
=
None
,
commit_hash
=
None
):
def
try_to_load_from_cache
(
repo_id
:
str
,
filename
:
str
,
cache_dir
:
Union
[
str
,
Path
,
None
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
)
->
Optional
[
str
]:
"""
"""
Explores the cache to return the latest cached file for a given revision.
Explores the cache to return the latest cached file for a given revision if found.
This function will not raise any exception if the file is not cached.
Args:
Args:
cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
cache_dir (`str` or `os.PathLike`):
repo_id (`str`): The ID of the repo on huggingface.co.
The folder where the cached files lie.
filename (`str`): The filename to look for inside `repo_id`.
repo_id (`str`):
The ID of the repo on huggingface.co.
filename (`str`):
The filename to look for inside `repo_id`.
revision (`str`, *optional*):
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
provided either.
commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.
Returns:
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
`Optional[str]` or `_CACHED_NO_EXIST`:
...
@@ -242,36 +251,36 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
...
@@ -242,36 +251,36 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
cached.
"""
"""
if
commit_hash
is
not
None
and
revision
is
not
None
:
if
revision
is
None
:
raise
ValueError
(
"`commit_hash` and `revision` are mutually exclusive, pick one only."
)
if
revision
is
None
and
commit_hash
is
None
:
revision
=
"main"
revision
=
"main"
model_id
=
repo_id
.
replace
(
"/"
,
"--"
)
if
cache_dir
is
None
:
model_cache
=
os
.
path
.
join
(
cache_dir
,
f
"models--
{
model_id
}
"
)
cache_dir
=
TRANSFORMERS_CACHE
if
not
os
.
path
.
isdir
(
model_cache
):
object_id
=
repo_id
.
replace
(
"/"
,
"--"
)
repo_cache
=
os
.
path
.
join
(
cache_dir
,
f
"models--
{
object_id
}
"
)
if
not
os
.
path
.
isdir
(
repo_cache
):
# No cache for this model
# No cache for this model
return
None
return
None
for
subfolder
in
[
"refs"
,
"snapshots"
]:
for
subfolder
in
[
"refs"
,
"snapshots"
]:
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
model
_cache
,
subfolder
)):
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
repo
_cache
,
subfolder
)):
return
None
return
None
if
commit_hash
is
None
:
# Resolve refs (for instance to convert main to the associated commit sha)
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs
=
os
.
listdir
(
os
.
path
.
join
(
repo_cache
,
"refs"
))
cached_refs
=
os
.
listdir
(
os
.
path
.
join
(
model_cache
,
"refs"
))
if
revision
in
cached_refs
:
if
revision
in
cached_refs
:
with
open
(
os
.
path
.
join
(
repo_cache
,
"refs"
,
revision
))
as
f
:
with
open
(
os
.
path
.
join
(
model_cache
,
"refs"
,
revision
))
as
f
:
revision
=
f
.
read
()
commit_hash
=
f
.
read
()
if
os
.
path
.
isfile
(
os
.
path
.
join
(
model
_cache
,
".no_exist"
,
commit_hash
,
filename
)):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
repo
_cache
,
".no_exist"
,
revision
,
filename
)):
return
_CACHED_NO_EXIST
return
_CACHED_NO_EXIST
cached_shas
=
os
.
listdir
(
os
.
path
.
join
(
model
_cache
,
"snapshots"
))
cached_shas
=
os
.
listdir
(
os
.
path
.
join
(
repo
_cache
,
"snapshots"
))
if
commit_hash
not
in
cached_shas
:
if
revision
not
in
cached_shas
:
# No cache for this revision and we won't try to return a random revision
# No cache for this revision and we won't try to return a random revision
return
None
return
None
cached_file
=
os
.
path
.
join
(
model
_cache
,
"snapshots"
,
commit_hash
,
filename
)
cached_file
=
os
.
path
.
join
(
repo
_cache
,
"snapshots"
,
revision
,
filename
)
return
cached_file
if
os
.
path
.
isfile
(
cached_file
)
else
None
return
cached_file
if
os
.
path
.
isfile
(
cached_file
)
else
None
...
@@ -375,7 +384,9 @@ def cached_file(
...
@@ -375,7 +384,9 @@ def cached_file(
if
_commit_hash
is
not
None
:
if
_commit_hash
is
not
None
:
# If the file is cached under that commit hash, we return it directly.
# If the file is cached under that commit hash, we return it directly.
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
commit_hash
=
_commit_hash
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
_commit_hash
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
if
resolved_file
is
not
_CACHED_NO_EXIST
:
if
resolved_file
is
not
_CACHED_NO_EXIST
:
return
resolved_file
return
resolved_file
...
@@ -416,7 +427,7 @@ def cached_file(
...
@@ -416,7 +427,7 @@ def cached_file(
)
)
except
LocalEntryNotFoundError
:
except
LocalEntryNotFoundError
:
# We try to see if we have a cached version (not up to date):
# We try to see if we have a cached version (not up to date):
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
revision
=
revision
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
revision
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
return
resolved_file
return
resolved_file
if
not
_raise_exceptions_for_missing_entries
or
not
_raise_exceptions_for_connection_errors
:
if
not
_raise_exceptions_for_missing_entries
or
not
_raise_exceptions_for_connection_errors
:
...
@@ -438,7 +449,7 @@ def cached_file(
...
@@ -438,7 +449,7 @@ def cached_file(
)
)
except
HTTPError
as
err
:
except
HTTPError
as
err
:
# First we try to see if we have a cached version (not up to date):
# First we try to see if we have a cached version (not up to date):
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
revision
=
revision
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
revision
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
return
resolved_file
return
resolved_file
if
not
_raise_exceptions_for_connection_errors
:
if
not
_raise_exceptions_for_connection_errors
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment