Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f7ceda34
Unverified
Commit
f7ceda34
authored
Sep 12, 2022
by
Sylvain Gugger
Committed by
GitHub
Sep 12, 2022
Browse files
Align try_to_load_from_cache with huggingface_hub (#18966)
* Align try_to_load_from_cache with huggingface_hub * Fix tests
parent
cf450b77
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
26 deletions
+37
-26
src/transformers/utils/hub.py
src/transformers/utils/hub.py
+37
-26
No files found.
src/transformers/utils/hub.py
View file @
f7ceda34
...
@@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
...
@@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
return
commit_hash
if
REGEX_COMMIT_HASH
.
match
(
commit_hash
)
else
None
return
commit_hash
if
REGEX_COMMIT_HASH
.
match
(
commit_hash
)
else
None
def
try_to_load_from_cache
(
cache_dir
,
repo_id
,
filename
,
revision
=
None
,
commit_hash
=
None
):
def
try_to_load_from_cache
(
repo_id
:
str
,
filename
:
str
,
cache_dir
:
Union
[
str
,
Path
,
None
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
)
->
Optional
[
str
]:
"""
"""
Explores the cache to return the latest cached file for a given revision.
Explores the cache to return the latest cached file for a given revision if found.
This function will not raise any exception if the file in not cached.
Args:
Args:
cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
cache_dir (`str` or `os.PathLike`):
repo_id (`str`): The ID of the repo on huggingface.co.
The folder where the cached files lie.
filename (`str`): The filename to look for inside `repo_id`.
repo_id (`str`):
The ID of the repo on huggingface.co.
filename (`str`):
The filename to look for inside `repo_id`.
revision (`str`, *optional*):
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
provided either.
commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.
Returns:
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
`Optional[str]` or `_CACHED_NO_EXIST`:
...
@@ -242,36 +251,36 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
...
@@ -242,36 +251,36 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
cached.
"""
"""
if
commit_hash
is
not
None
and
revision
is
not
None
:
if
revision
is
None
:
raise
ValueError
(
"`commit_hash` and `revision` are mutually exclusive, pick one only."
)
if
revision
is
None
and
commit_hash
is
None
:
revision
=
"main"
revision
=
"main"
model_id
=
repo_id
.
replace
(
"/"
,
"--"
)
if
cache_dir
is
None
:
model_cache
=
os
.
path
.
join
(
cache_dir
,
f
"models--
{
model_id
}
"
)
cache_dir
=
TRANSFORMERS_CACHE
if
not
os
.
path
.
isdir
(
model_cache
):
object_id
=
repo_id
.
replace
(
"/"
,
"--"
)
repo_cache
=
os
.
path
.
join
(
cache_dir
,
f
"models--
{
object_id
}
"
)
if
not
os
.
path
.
isdir
(
repo_cache
):
# No cache for this model
# No cache for this model
return
None
return
None
for
subfolder
in
[
"refs"
,
"snapshots"
]:
for
subfolder
in
[
"refs"
,
"snapshots"
]:
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
model
_cache
,
subfolder
)):
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
repo
_cache
,
subfolder
)):
return
None
return
None
if
commit_hash
is
None
:
# Resolve refs (for instance to convert main to the associated commit sha)
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs
=
os
.
listdir
(
os
.
path
.
join
(
model
_cache
,
"refs"
))
cached_refs
=
os
.
listdir
(
os
.
path
.
join
(
repo
_cache
,
"refs"
))
if
revision
in
cached_refs
:
if
revision
in
cached_refs
:
with
open
(
os
.
path
.
join
(
model
_cache
,
"refs"
,
revision
))
as
f
:
with
open
(
os
.
path
.
join
(
repo
_cache
,
"refs"
,
revision
))
as
f
:
commit_hash
=
f
.
read
()
revision
=
f
.
read
()
if
os
.
path
.
isfile
(
os
.
path
.
join
(
model
_cache
,
".no_exist"
,
commit_hash
,
filename
)):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
repo
_cache
,
".no_exist"
,
revision
,
filename
)):
return
_CACHED_NO_EXIST
return
_CACHED_NO_EXIST
cached_shas
=
os
.
listdir
(
os
.
path
.
join
(
model
_cache
,
"snapshots"
))
cached_shas
=
os
.
listdir
(
os
.
path
.
join
(
repo
_cache
,
"snapshots"
))
if
commit_hash
not
in
cached_shas
:
if
revision
not
in
cached_shas
:
# No cache for this revision and we won't try to return a random revision
# No cache for this revision and we won't try to return a random revision
return
None
return
None
cached_file
=
os
.
path
.
join
(
model
_cache
,
"snapshots"
,
commit_hash
,
filename
)
cached_file
=
os
.
path
.
join
(
repo
_cache
,
"snapshots"
,
revision
,
filename
)
return
cached_file
if
os
.
path
.
isfile
(
cached_file
)
else
None
return
cached_file
if
os
.
path
.
isfile
(
cached_file
)
else
None
...
@@ -375,7 +384,9 @@ def cached_file(
...
@@ -375,7 +384,9 @@ def cached_file(
if
_commit_hash
is
not
None
:
if
_commit_hash
is
not
None
:
# If the file is cached under that commit hash, we return it directly.
# If the file is cached under that commit hash, we return it directly.
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
commit_hash
=
_commit_hash
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
_commit_hash
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
if
resolved_file
is
not
_CACHED_NO_EXIST
:
if
resolved_file
is
not
_CACHED_NO_EXIST
:
return
resolved_file
return
resolved_file
...
@@ -416,7 +427,7 @@ def cached_file(
...
@@ -416,7 +427,7 @@ def cached_file(
)
)
except
LocalEntryNotFoundError
:
except
LocalEntryNotFoundError
:
# We try to see if we have a cached version (not up to date):
# We try to see if we have a cached version (not up to date):
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
revision
=
revision
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
revision
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
return
resolved_file
return
resolved_file
if
not
_raise_exceptions_for_missing_entries
or
not
_raise_exceptions_for_connection_errors
:
if
not
_raise_exceptions_for_missing_entries
or
not
_raise_exceptions_for_connection_errors
:
...
@@ -438,7 +449,7 @@ def cached_file(
...
@@ -438,7 +449,7 @@ def cached_file(
)
)
except
HTTPError
as
err
:
except
HTTPError
as
err
:
# First we try to see if we have a cached version (not up to date):
# First we try to see if we have a cached version (not up to date):
resolved_file
=
try_to_load_from_cache
(
cache_dir
,
path_or_repo_id
,
full_filename
,
revision
=
revision
)
resolved_file
=
try_to_load_from_cache
(
path_or_repo_id
,
full_filename
,
cache_dir
=
cache_dir
,
revision
=
revision
)
if
resolved_file
is
not
None
:
if
resolved_file
is
not
None
:
return
resolved_file
return
resolved_file
if
not
_raise_exceptions_for_connection_errors
:
if
not
_raise_exceptions_for_connection_errors
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment