Unverified Commit e47433b3 authored by acisseJZhong's avatar acisseJZhong Committed by GitHub
Browse files

[BugFix] Pass config_format via try_get_generation_config (#25912)

parent 23194d83
...@@ -1334,11 +1334,13 @@ class ModelConfig: ...@@ -1334,11 +1334,13 @@ class ModelConfig:
self.hf_config_path or self.model, self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
revision=self.revision, revision=self.revision,
config_format=self.config_format,
) )
else: else:
config = try_get_generation_config( config = try_get_generation_config(
self.generation_config, self.generation_config,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
config_format=self.config_format,
) )
if config is None: if config is None:
......
...@@ -949,6 +949,7 @@ def try_get_generation_config( ...@@ -949,6 +949,7 @@ def try_get_generation_config(
model: str, model: str,
trust_remote_code: bool, trust_remote_code: bool,
revision: Optional[str] = None, revision: Optional[str] = None,
config_format: Union[str, ConfigFormat] = "auto",
) -> Optional[GenerationConfig]: ) -> Optional[GenerationConfig]:
try: try:
return GenerationConfig.from_pretrained( return GenerationConfig.from_pretrained(
...@@ -961,6 +962,7 @@ def try_get_generation_config( ...@@ -961,6 +962,7 @@ def try_get_generation_config(
model, model,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
config_format=config_format,
) )
return GenerationConfig.from_model_config(config) return GenerationConfig.from_model_config(config)
except OSError: # Not found except OSError: # Not found
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment