Unverified commit d88f28da, authored by Harry Mellor, committed by GitHub
Browse files

Fix `hf_override_fn` when it modifies `model_type` (#35200)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 106ff69c
...@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase): ...@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase):
) )
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY # Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
model_type = hf_overrides.get("model_type", model_type) if isinstance(hf_overrides, dict) and "model_type" in hf_overrides:
model_type = hf_overrides["model_type"]
elif callable(hf_overrides):
# If hf_overrides doesn't modify model_type, it will be passed straight
# through and remain unchanged by this elif block
dummy_model_type = f"dummy_{model_type}"
dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type)
dummy_config = PretrainedConfig(**dummy_kwargs)
dummy_model_type = hf_overrides(dummy_config).model_type
model_type = dummy_model_type.removeprefix("dummy_")
if model_type in _CONFIG_REGISTRY: if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type] config_class = _CONFIG_REGISTRY[model_type]
...@@ -634,7 +643,7 @@ def get_config( ...@@ -634,7 +643,7 @@ def get_config(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
hf_overrides=hf_overrides_kw, hf_overrides=hf_overrides_kw or hf_overrides_fn,
**kwargs, **kwargs,
) )
......
...@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase: ...@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase:
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head return self.hf_text_config.hidden_size_per_head
if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0:
return 0
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return ( return self.get_hidden_size() // total_num_attention_heads
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
)
def get_total_num_kv_heads(self) -> int: def get_total_num_kv_heads(self) -> int:
attributes = [ attributes = [
...@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase: ...@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase:
] ]
# For non-grouped-query attention models, the number of KV heads is # For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads. # equal to the number of attention heads.
default_factory = lambda: self.hf_text_config.num_attention_heads default_factory = self.get_total_num_attention_heads
return getattr_iter( return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory self.hf_text_config, attributes, default_factory=default_factory
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment