Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c4033ac
Unverified
Commit
7c4033ac
authored
Feb 12, 2025
by
Maximilien de Bayser
Committed by
GitHub
Feb 12, 2025
Browse files
Further reduce the HTTP calls to huggingface.co (#13107)
parent
d59def47
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
56 deletions
+79
-56
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+79
-56
No files found.
vllm/transformers_utils/config.py
View file @
7c4033ac
...
@@ -4,12 +4,14 @@ import enum
...
@@ -4,12 +4,14 @@ import enum
import
json
import
json
import
os
import
os
import
time
import
time
from
functools
import
cache
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Literal
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
Literal
,
Optional
,
Type
,
Union
import
huggingface_hub
import
huggingface_hub
from
huggingface_hub
import
(
file_exists
,
hf_hub_download
,
list_repo_files
,
from
huggingface_hub
import
hf_hub_download
try_to_load_from_cache
)
from
huggingface_hub
import
list_repo_files
as
hf_list_repo_files
from
huggingface_hub
import
try_to_load_from_cache
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
HFValidationError
,
LocalEntryNotFoundError
,
HFValidationError
,
LocalEntryNotFoundError
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
...
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
...
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
MISTRAL
=
"mistral"
MISTRAL
=
"mistral"
def
with_retry
(
func
:
Callable
[[],
Any
],
log_msg
:
str
,
max_retries
:
int
=
2
,
retry_delay
:
int
=
2
):
for
attempt
in
range
(
max_retries
):
try
:
return
func
()
except
Exception
as
e
:
if
attempt
==
max_retries
-
1
:
logger
.
error
(
"%s: %s"
,
log_msg
,
e
)
raise
logger
.
error
(
"%s: %s, retrying %d of %d"
,
log_msg
,
e
,
attempt
+
1
,
max_retries
)
time
.
sleep
(
retry_delay
)
retry_delay
*=
2
# @cache doesn't cache exceptions
@
cache
def
list_repo_files
(
repo_id
:
str
,
*
,
revision
:
Optional
[
str
]
=
None
,
repo_type
:
Optional
[
str
]
=
None
,
token
:
Union
[
str
,
bool
,
None
]
=
None
,
)
->
list
[
str
]:
def
lookup_files
():
try
:
return
hf_list_repo_files
(
repo_id
,
revision
=
revision
,
repo_type
=
repo_type
,
token
=
token
)
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
# Don't raise in offline mode,
# all we know is that we don't have this
# file cached.
return
[]
return
with_retry
(
lookup_files
,
"Error retrieving file list"
)
def
file_exists
(
repo_id
:
str
,
file_name
:
str
,
*
,
repo_type
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
token
:
Union
[
str
,
bool
,
None
]
=
None
,
)
->
bool
:
file_list
=
list_repo_files
(
repo_id
,
repo_type
=
repo_type
,
revision
=
revision
,
token
=
token
)
return
file_name
in
file_list
# In offline mode the result can be a false negative
def
file_or_path_exists
(
model
:
Union
[
str
,
Path
],
config_name
:
str
,
def
file_or_path_exists
(
model
:
Union
[
str
,
Path
],
config_name
:
str
,
revision
:
Optional
[
str
])
->
bool
:
revision
:
Optional
[
str
])
->
bool
:
if
Path
(
model
).
exists
():
if
Path
(
model
).
exists
():
...
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
...
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
# hf_hub. This will fail in offline mode.
# hf_hub. This will fail in offline mode.
# Call HF to check if the file exists
# Call HF to check if the file exists
# 2 retries and exponential backoff
return
file_exists
(
str
(
model
),
max_retries
=
2
retry_delay
=
2
for
attempt
in
range
(
max_retries
):
try
:
return
file_exists
(
model
,
config_name
,
config_name
,
revision
=
revision
,
revision
=
revision
,
token
=
HF_TOKEN
)
token
=
HF_TOKEN
)
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
# Don't raise in offline mode,
# all we know is that we don't have this
# file cached.
return
False
except
Exception
as
e
:
logger
.
error
(
"Error checking file existence: %s, retrying %d of %d"
,
e
,
attempt
+
1
,
max_retries
)
if
attempt
==
max_retries
-
1
:
logger
.
error
(
"Error checking file existence: %s"
,
e
)
raise
time
.
sleep
(
retry_delay
)
retry_delay
*=
2
continue
return
False
def
patch_rope_scaling
(
config
:
PretrainedConfig
)
->
None
:
def
patch_rope_scaling
(
config
:
PretrainedConfig
)
->
None
:
...
@@ -208,32 +248,7 @@ def get_config(
...
@@ -208,32 +248,7 @@ def get_config(
revision
=
revision
):
revision
=
revision
):
config_format
=
ConfigFormat
.
MISTRAL
config_format
=
ConfigFormat
.
MISTRAL
else
:
else
:
# If we're in offline mode and found no valid config format, then
raise
ValueError
(
f
"No supported config format found in
{
model
}
."
)
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
# Call HF to check if the file exists
# 2 retries and exponential backoff
max_retries
=
2
retry_delay
=
2
for
attempt
in
range
(
max_retries
):
try
:
file_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
HF_TOKEN
)
except
Exception
as
e
:
logger
.
error
(
"Error checking file existence: %s, retrying %d of %d"
,
e
,
attempt
+
1
,
max_retries
)
if
attempt
==
max_retries
:
logger
.
error
(
"Error checking file existence: %s"
,
e
)
raise
e
time
.
sleep
(
retry_delay
)
retry_delay
*=
2
raise
ValueError
(
f
"No supported config format found in
{
model
}
"
)
if
config_format
==
ConfigFormat
.
HF
:
if
config_format
==
ConfigFormat
.
HF
:
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
...
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
...
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
file_name
=
file_name
,
file_name
=
file_name
,
revision
=
revision
)
revision
=
revision
)
if
file_path
is
None
and
file_or_path_exists
(
if
file_path
is
None
:
model
=
model
,
config_name
=
file_name
,
revision
=
revision
):
try
:
try
:
hf_hub_file
=
hf_hub_download
(
model
,
file_name
,
revision
=
revision
)
hf_hub_file
=
hf_hub_download
(
model
,
file_name
,
revision
=
revision
)
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
return
None
except
(
RepositoryNotFoundError
,
RevisionNotFoundError
,
except
(
RepositoryNotFoundError
,
RevisionNotFoundError
,
EntryNotFoundError
,
LocalEntryNotFoundError
)
as
e
:
EntryNotFoundError
,
LocalEntryNotFoundError
)
as
e
:
logger
.
debug
(
"File or repository not found in hf_hub_download"
,
e
)
logger
.
debug
(
"File or repository not found in hf_hub_download"
,
e
)
...
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
...
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
return
None
return
None
@
cache
def
get_pooling_config
(
model
:
str
,
revision
:
Optional
[
str
]
=
'main'
):
def
get_pooling_config
(
model
:
str
,
revision
:
Optional
[
str
]
=
'main'
):
"""
"""
This function gets the pooling and normalize
This function gets the pooling and normalize
...
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
...
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
if
modules_dict
is
None
:
if
modules_dict
is
None
:
return
None
return
None
logger
.
info
(
"Found sentence-transformers modules configuration."
)
pooling
=
next
((
item
for
item
in
modules_dict
pooling
=
next
((
item
for
item
in
modules_dict
if
item
[
"type"
]
==
"sentence_transformers.models.Pooling"
),
if
item
[
"type"
]
==
"sentence_transformers.models.Pooling"
),
None
)
None
)
...
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
...
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
if
pooling_type_name
is
not
None
:
if
pooling_type_name
is
not
None
:
pooling_type_name
=
get_pooling_config_name
(
pooling_type_name
)
pooling_type_name
=
get_pooling_config_name
(
pooling_type_name
)
logger
.
info
(
"Found pooling configuration."
)
return
{
"pooling_type"
:
pooling_type_name
,
"normalize"
:
normalize
}
return
{
"pooling_type"
:
pooling_type_name
,
"normalize"
:
normalize
}
return
None
return
None
...
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
...
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
return
None
return
None
@
cache
def
get_sentence_transformer_tokenizer_config
(
model
:
str
,
def
get_sentence_transformer_tokenizer_config
(
model
:
str
,
revision
:
Optional
[
str
]
=
'main'
revision
:
Optional
[
str
]
=
'main'
):
):
...
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
...
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
if
not
encoder_dict
:
if
not
encoder_dict
:
return
None
return
None
logger
.
info
(
"Found sentence-transformers tokenize configuration."
)
if
all
(
k
in
encoder_dict
for
k
in
(
"max_seq_length"
,
"do_lower_case"
)):
if
all
(
k
in
encoder_dict
for
k
in
(
"max_seq_length"
,
"do_lower_case"
)):
return
encoder_dict
return
encoder_dict
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment