Unverified Commit 6cbc4d4b authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Model] Add ModelConfig class for GraniteMoeHybrid to override default...


[Model] Add ModelConfig class for GraniteMoeHybrid to override default max_seq_len_to_capture (#20923)
Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 153c6f1e
...@@ -205,6 +205,19 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): ...@@ -205,6 +205,19 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
} }
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config
config.max_seq_len_to_capture = config.max_model_len
logger.info(
"Setting max_seq_len_to_capture to %d "
"to ensure that CUDA graph capture "
"covers sequences of length up to max_model_len.",
config.max_model_len)
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod @classmethod
...@@ -297,4 +310,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -297,4 +310,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"XLMRobertaModel": JinaRobertaModelConfig, "XLMRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment