Unverified commit da1e7311, authored by Isotr0py, committed by GitHub
Browse files

[Misc] use model arch converter for bidi models identification (#40701)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 01cb41dc
......@@ -1197,22 +1197,9 @@ class ModelConfig:
def is_deepseek_mla(self) -> bool:
    """Return the ``is_deepseek_mla`` flag held by the model arch config."""
    arch_config = self.model_arch_config
    return arch_config.is_deepseek_mla
@property
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions.

    The value is read from ``self.model_arch_config`` rather than being
    re-derived from ``self.hf_config`` here, so this accessor never
    consults the HF config directly.

    Only ``@property`` is applied: stacking ``@cached_property`` on top of
    ``@property`` would wrap a non-callable ``property`` object and raise
    ``TypeError`` on first access.
    """
    return self.model_arch_config.is_mm_prefix_lm
def get_head_size(self) -> int:
    """Return the ``head_size`` value held by the model arch config."""
    arch_config = self.model_arch_config
    return arch_config.head_size
......
......@@ -53,5 +53,8 @@ class ModelArchitectureConfig:
is_deepseek_mla: bool
"""Whether the model is a DeepSeek MLA model."""
is_mm_prefix_lm: bool
"""Whether the model uses image bidirectional attention."""
derived_max_model_len_and_key: tuple[float, str | None]
"""Derived maximum model length and key from the hf config."""
......@@ -250,6 +250,22 @@ class ModelArchConfigConvertorBase:
)
return False
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions.

    An ``is_mm_prefix_lm`` attribute on the HF config, when present, decides
    the answer (coerced to ``bool``). Otherwise the config's ``model_type``
    is checked against a built-in list of known model types; a config with
    no ``model_type`` yields ``False``.
    """
    config = self.hf_config

    # An explicit flag on the HF config takes precedence over the list.
    if hasattr(config, "is_mm_prefix_lm"):
        return bool(config.is_mm_prefix_lm)

    # Fallback: model types known to need bidirectional mm attention.
    known_model_types = (
        "bagel",
        "gemma3",
        "molmo2",
        "paligemma",
        "umm",
    )
    # A missing model_type becomes None, which is never in the tuple.
    return getattr(config, "model_type", None) in known_model_types
def derive_max_model_len_and_key(self) -> tuple[float, str | None]:
derived_max_model_len = float("inf")
possible_keys = [
......@@ -299,6 +315,7 @@ class ModelArchConfigConvertorBase:
num_experts=self.get_num_experts(),
quantization_config=self.get_quantization_config(),
is_deepseek_mla=self.is_deepseek_mla(),
is_mm_prefix_lm=self.is_mm_prefix_lm(),
derived_max_model_len_and_key=self.derive_max_model_len_and_key(),
)
......@@ -451,6 +468,12 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions.

    True only when the HF text config's ``use_bidirectional_attention``
    attribute equals ``"vision"``; a missing attribute is treated as
    ``None`` and therefore yields ``False``.
    """
    attention_mode = getattr(
        self.hf_text_config, "use_bidirectional_attention", None
    )
    return attention_mode == "vision"
def get_head_size(self) -> int:
# Gemma4 uses dual head dimensions: head_dim (sliding attention)
# and global_head_dim (full attention). Return the largest so
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment