Unverified commit 54b0578a, authored by Javier De Jesus, committed by GitHub
Browse files

[Bugfix] Pass hf_token through config loading paths for gated model support (#37920)


Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
parent 89f572db
...@@ -488,6 +488,7 @@ class ModelConfig: # type: ignore[misc] ...@@ -488,6 +488,7 @@ class ModelConfig: # type: ignore[misc]
self.config_format, self.config_format,
hf_overrides_kw=hf_overrides_kw, hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn, hf_overrides_fn=hf_overrides_fn,
token=self.hf_token,
) )
hf_config = maybe_patch_hf_config_from_gguf( hf_config = maybe_patch_hf_config_from_gguf(
self.model, self.model,
...@@ -1341,12 +1342,14 @@ class ModelConfig: # type: ignore[misc] ...@@ -1341,12 +1342,14 @@ class ModelConfig: # type: ignore[misc]
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
revision=self.revision, revision=self.revision,
config_format=self.config_format, config_format=self.config_format,
hf_token=self.hf_token,
) )
else: else:
config = try_get_generation_config( config = try_get_generation_config(
self.generation_config, self.generation_config,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
config_format=self.config_format, config_format=self.config_format,
hf_token=self.hf_token,
) )
if config is None: if config is None:
......
...@@ -1530,6 +1530,7 @@ class EngineArgs: ...@@ -1530,6 +1530,7 @@ class EngineArgs:
revision=self.revision, revision=self.revision,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
vllm_speculative_config=self.speculative_config, vllm_speculative_config=self.speculative_config,
hf_token=self.hf_token,
) )
) )
......
...@@ -542,6 +542,7 @@ def maybe_override_with_speculators( ...@@ -542,6 +542,7 @@ def maybe_override_with_speculators(
trust_remote_code: bool, trust_remote_code: bool,
revision: str | None = None, revision: str | None = None,
vllm_speculative_config: dict[str, Any] | None = None, vllm_speculative_config: dict[str, Any] | None = None,
hf_token: bool | str | None = None,
**kwargs, **kwargs,
) -> tuple[str, str | None, dict[str, Any] | None]: ) -> tuple[str, str | None, dict[str, Any] | None]:
""" """
...@@ -556,6 +557,7 @@ def maybe_override_with_speculators( ...@@ -556,6 +557,7 @@ def maybe_override_with_speculators(
trust_remote_code: Whether to trust remote code trust_remote_code: Whether to trust remote code
revision: Model revision revision: Model revision
vllm_speculative_config: Existing vLLM speculative config vllm_speculative_config: Existing vLLM speculative config
hf_token: HuggingFace token for authenticated model access
Returns: Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config) Tuple of (resolved_model, resolved_tokenizer, speculative_config)
...@@ -572,6 +574,7 @@ def maybe_override_with_speculators( ...@@ -572,6 +574,7 @@ def maybe_override_with_speculators(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model if gguf_model_repo is None else gguf_model_repo, model if gguf_model_repo is None else gguf_model_repo,
revision=revision, revision=revision,
token=hf_token,
**without_trust_remote_code(kwargs), **without_trust_remote_code(kwargs),
) )
speculators_config = config_dict.get("speculators_config") speculators_config = config_dict.get("speculators_config")
...@@ -1054,6 +1057,7 @@ def try_get_generation_config( ...@@ -1054,6 +1057,7 @@ def try_get_generation_config(
trust_remote_code: bool, trust_remote_code: bool,
revision: str | None = None, revision: str | None = None,
config_format: str | ConfigFormat = "auto", config_format: str | ConfigFormat = "auto",
hf_token: bool | str | None = None,
) -> GenerationConfig | None: ) -> GenerationConfig | None:
# GGUF files don't have generation_config.json - their config is embedded # GGUF files don't have generation_config.json - their config is embedded
# in the file header. Skip all filesystem lookups to avoid re-reading the # in the file header. Skip all filesystem lookups to avoid re-reading the
...@@ -1066,6 +1070,7 @@ def try_get_generation_config( ...@@ -1066,6 +1070,7 @@ def try_get_generation_config(
return GenerationConfig.from_pretrained( return GenerationConfig.from_pretrained(
model, model,
revision=revision, revision=revision,
token=hf_token,
) )
except OSError: # Not found except OSError: # Not found
try: try:
...@@ -1074,6 +1079,7 @@ def try_get_generation_config( ...@@ -1074,6 +1079,7 @@ def try_get_generation_config(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
config_format=config_format, config_format=config_format,
token=hf_token,
) )
return GenerationConfig.from_model_config(config) return GenerationConfig.from_model_config(config)
except OSError: # Not found except OSError: # Not found
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment