Unverified commit 24a163ed, authored by Lucain and committed by GitHub
Browse files

Cleanup some huggingface_hub-related stuff (#32788)

parent 378385b9
...@@ -5,12 +5,7 @@ import os ...@@ -5,12 +5,7 @@ import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import huggingface_hub import huggingface_hub
from huggingface_hub.utils import ( from huggingface_hub.utils import HfHubHTTPError, HFValidationError
EntryNotFoundError,
HfHubHTTPError,
HFValidationError,
RepositoryNotFoundError,
)
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -243,12 +238,7 @@ def get_adapter_absolute_path(lora_path: str) -> str: ...@@ -243,12 +238,7 @@ def get_adapter_absolute_path(lora_path: str) -> str:
# If the path does not exist locally, assume it's a Hugging Face repo. # If the path does not exist locally, assume it's a Hugging Face repo.
try: try:
local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path) local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path)
except ( except (HfHubHTTPError, HFValidationError):
HfHubHTTPError,
RepositoryNotFoundError,
EntryNotFoundError,
HFValidationError,
):
# Handle errors that may occur during the download # Handle errors that may occur during the download
# Return original path instead of throwing error here # Return original path instead of throwing error here
logger.exception("Error downloading the HuggingFace model") logger.exception("Error downloading the HuggingFace model")
......
...@@ -33,7 +33,6 @@ from .gguf_utils import ( ...@@ -33,7 +33,6 @@ from .gguf_utils import (
split_remote_gguf, split_remote_gguf,
) )
from .repo_utils import ( from .repo_utils import (
_get_hf_token,
file_or_path_exists, file_or_path_exists,
get_hf_file_to_dict, get_hf_file_to_dict,
list_repo_files, list_repo_files,
...@@ -135,7 +134,6 @@ class HFConfigParser(ConfigParserBase): ...@@ -135,7 +134,6 @@ class HFConfigParser(ConfigParserBase):
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs, **kwargs,
) )
# Use custom model class if it's in our registry # Use custom model class if it's in our registry
...@@ -157,7 +155,6 @@ class HFConfigParser(ConfigParserBase): ...@@ -157,7 +155,6 @@ class HFConfigParser(ConfigParserBase):
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs, **kwargs,
) )
else: else:
...@@ -168,7 +165,6 @@ class HFConfigParser(ConfigParserBase): ...@@ -168,7 +165,6 @@ class HFConfigParser(ConfigParserBase):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=_get_hf_token(),
**kwargs, **kwargs,
) )
except ValueError as e: except ValueError as e:
...@@ -218,7 +214,6 @@ class MistralConfigParser(ConfigParserBase): ...@@ -218,7 +214,6 @@ class MistralConfigParser(ConfigParserBase):
model, model,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=_get_hf_token(),
**kwargs, **kwargs,
) )
except OSError: # Not found except OSError: # Not found
...@@ -529,7 +524,6 @@ def maybe_override_with_speculators( ...@@ -529,7 +524,6 @@ def maybe_override_with_speculators(
model if gguf_model_repo is None else gguf_model_repo, model if gguf_model_repo is None else gguf_model_repo,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs, **kwargs,
) )
speculators_config = config_dict.get("speculators_config") speculators_config = config_dict.get("speculators_config")
...@@ -871,9 +865,7 @@ def get_sentence_transformer_tokenizer_config( ...@@ -871,9 +865,7 @@ def get_sentence_transformer_tokenizer_config(
if not encoder_dict and not Path(model).is_absolute(): if not encoder_dict and not Path(model).is_absolute():
try: try:
# If model is on HuggingfaceHub, get the repo files # If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files( repo_files = list_repo_files(model, revision=revision)
model, revision=revision, token=_get_hf_token()
)
except Exception: except Exception:
repo_files = [] repo_files = []
...@@ -1042,10 +1034,7 @@ def try_get_safetensors_metadata( ...@@ -1042,10 +1034,7 @@ def try_get_safetensors_metadata(
revision: str | None = None, revision: str | None = None,
): ):
get_safetensors_metadata_partial = partial( get_safetensors_metadata_partial = partial(
get_safetensors_metadata, get_safetensors_metadata, model, revision=revision
model,
revision=revision,
token=_get_hf_token(),
) )
try: try:
......
...@@ -12,10 +12,7 @@ from pathlib import Path ...@@ -12,10 +12,7 @@ from pathlib import Path
from typing import TypeVar from typing import TypeVar
import huggingface_hub import huggingface_hub
from huggingface_hub import ( from huggingface_hub import hf_hub_download, try_to_load_from_cache
hf_hub_download,
try_to_load_from_cache,
)
from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub.utils import ( from huggingface_hub.utils import (
EntryNotFoundError, EntryNotFoundError,
...@@ -31,21 +28,6 @@ from vllm.logger import init_logger ...@@ -31,21 +28,6 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_hf_token() -> str | None:
"""
Get the HuggingFace token from environment variable.
Returns None if the token is not set, is an empty string,
or contains only whitespace.
This follows the same pattern as huggingface_hub library which
treats empty string tokens as None to avoid authentication errors.
"""
token = os.getenv("HF_TOKEN")
if token and token.strip():
return token
return None
_R = TypeVar("_R") _R = TypeVar("_R")
...@@ -153,6 +135,8 @@ def file_exists( ...@@ -153,6 +135,8 @@ def file_exists(
revision: str | None = None, revision: str | None = None,
token: str | bool | None = None, token: str | bool | None = None,
) -> bool: ) -> bool:
# `list_repo_files` is cached and retried on error, so this is more efficient than
# huggingface_hub.file_exists default implementation when looking for multiple files
file_list = list_repo_files( file_list = list_repo_files(
repo_id, repo_type=repo_type, revision=revision, token=token repo_id, repo_type=repo_type, revision=revision, token=token
) )
...@@ -178,9 +162,7 @@ def file_or_path_exists( ...@@ -178,9 +162,7 @@ def file_or_path_exists(
# hf_hub. This will fail in offline mode. # hf_hub. This will fail in offline mode.
# Call HF to check if the file exists # Call HF to check if the file exists
return file_exists( return file_exists(str(model), config_name, revision=revision)
str(model), config_name, revision=revision, token=_get_hf_token()
)
def get_model_path(model: str | Path, revision: str | None = None): def get_model_path(model: str | Path, revision: str | None = None):
...@@ -209,9 +191,7 @@ def get_hf_file_bytes( ...@@ -209,9 +191,7 @@ def get_hf_file_bytes(
file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
if file_path is None: if file_path is None:
hf_hub_file = hf_hub_download( hf_hub_file = hf_hub_download(model, file_name, revision=revision)
model, file_name, revision=revision, token=_get_hf_token()
)
file_path = Path(hf_hub_file) file_path = Path(hf_hub_file)
if file_path is not None and file_path.is_file(): if file_path is not None and file_path.is_file():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment