Unverified Commit 0eee877f authored by Xingyu Liu's avatar Xingyu Liu Committed by GitHub
Browse files

[Core] Parse vLLM engine required fields from hf_config to model_arch_config (#28454)


Signed-off-by: default avatarXingyu Liu <charlotteliu12x@gmail.com>
Signed-off-by: default avatarXingyu Liu <38244988+charlotte12l@users.noreply.github.com>
parent a0e9ee83
{
"state-spaces/mamba-130m-hf": {
"architectures": [
"MambaForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 768,
"total_num_hidden_layers": 24,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 50280,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.float32"
},
"mistralai/Mamba-Codestral-7B-v0.1": {
"architectures": [
"Mamba2ForCausalLM"
],
"model_type": "mamba",
"text_model_type": "mamba",
"hidden_size": 4096,
"total_num_hidden_layers": 64,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 32768,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
"architectures": [
"Terratorch"
],
"model_type": "timm_wrapper",
"text_model_type": "timm_wrapper",
"hidden_size": 0,
"total_num_hidden_layers": 0,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 0,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": true,
"dtype": "torch.float32"
},
"tiiuae/falcon-mamba-7b-instruct": {
"architectures": [
"FalconMambaForCausalLM"
],
"model_type": "falcon_mamba",
"text_model_type": "falcon_mamba",
"hidden_size": 4096,
"total_num_hidden_layers": 64,
"total_num_attention_heads": 0,
"head_size": 0,
"vocab_size": 65024,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"Zyphra/Zamba2-7B-instruct": {
"architectures": [
"Zamba2ForCausalLM"
],
"model_type": "zamba2",
"text_model_type": "zamba2",
"hidden_size": 3584,
"total_num_hidden_layers": 81,
"total_num_attention_heads": 32,
"head_size": 224,
"vocab_size": 32000,
"total_num_kv_heads": 32,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"mosaicml/mpt-7b": {
"architectures": [
"MPTForCausalLM"
],
"model_type": "mpt",
"text_model_type": "mpt",
"hidden_size": 4096,
"total_num_hidden_layers": 32,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 50432,
"total_num_kv_heads": 32,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"databricks/dbrx-instruct": {
"architectures": [
"DbrxForCausalLM"
],
"model_type": "dbrx",
"text_model_type": "dbrx",
"hidden_size": 6144,
"total_num_hidden_layers": 40,
"total_num_attention_heads": 48,
"head_size": 128,
"vocab_size": 100352,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiiuae/falcon-7b": {
"architectures": [
"FalconForCausalLM"
],
"model_type": "falcon",
"text_model_type": "falcon",
"hidden_size": 4544,
"total_num_hidden_layers": 32,
"total_num_attention_heads": 71,
"head_size": 64,
"vocab_size": 65024,
"total_num_kv_heads": 1,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiiuae/falcon-40b": {
"architectures": [
"FalconForCausalLM"
],
"model_type": "falcon",
"text_model_type": "falcon",
"hidden_size": 8192,
"total_num_hidden_layers": 60,
"total_num_attention_heads": 128,
"head_size": 64,
"vocab_size": 65024,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"luccafong/deepseek_mtp_main_random": {
"architectures": [
"DeepseekV3ForCausalLM"
],
"model_type": "deepseek_v3",
"text_model_type": "deepseek_v3",
"hidden_size": 2560,
"total_num_hidden_layers": 5,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"luccafong/deepseek_mtp_draft_random": {
"architectures": [
"DeepseekV3ForCausalLM"
],
"model_type": "deepseek_v3",
"text_model_type": "deepseek_v3",
"hidden_size": 2560,
"total_num_hidden_layers": 10,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"Qwen/Qwen3-Next-80B-A3B-Instruct": {
"architectures": [
"Qwen3NextForCausalLM"
],
"model_type": "qwen3_next",
"text_model_type": "qwen3_next",
"hidden_size": 2048,
"total_num_hidden_layers": 48,
"total_num_attention_heads": 16,
"head_size": 256,
"vocab_size": 151936,
"total_num_kv_heads": 2,
"num_experts": 512,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"tiny-random/qwen3-next-moe": {
"architectures": [
"Qwen3NextForCausalLM"
],
"model_type": "qwen3_next",
"text_model_type": "qwen3_next",
"hidden_size": 8,
"total_num_hidden_layers": 4,
"total_num_attention_heads": 16,
"head_size": 32,
"vocab_size": 151936,
"total_num_kv_heads": 8,
"num_experts": 32,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"zai-org/GLM-4.5": {
"architectures": [
"Glm4MoeForCausalLM"
],
"model_type": "glm4_moe",
"text_model_type": "glm4_moe",
"hidden_size": 5120,
"total_num_hidden_layers": 92,
"total_num_attention_heads": 96,
"head_size": 128,
"vocab_size": 151552,
"total_num_kv_heads": 8,
"num_experts": 160,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"baidu/ERNIE-4.5-21B-A3B-PT": {
"architectures": [
"Ernie4_5_MoeForCausalLM"
],
"model_type": "ernie4_5_moe",
"text_model_type": "ernie4_5_moe",
"hidden_size": 2560,
"total_num_hidden_layers": 28,
"total_num_attention_heads": 20,
"head_size": 128,
"vocab_size": 103424,
"total_num_kv_heads": 4,
"num_experts": 64,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"lmsys/gpt-oss-20b-bf16": {
"architectures": [
"GptOssForCausalLM"
],
"model_type": "gpt_oss",
"text_model_type": "gpt_oss",
"hidden_size": 2880,
"total_num_hidden_layers": 24,
"total_num_attention_heads": 64,
"head_size": 64,
"vocab_size": 201088,
"total_num_kv_heads": 8,
"num_experts": 32,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"deepseek-ai/DeepSeek-V3.2-Exp": {
"architectures": [
"DeepseekV32ForCausalLM"
],
"model_type": "deepseek_v32",
"text_model_type": "deepseek_v32",
"hidden_size": 7168,
"total_num_hidden_layers": 61,
"total_num_attention_heads": 128,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 128,
"num_experts": 256,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"meta-llama/Llama-4-Scout-17B-16E-Instruct": {
"architectures": [
"Llama4ForConditionalGeneration"
],
"model_type": "llama4",
"text_model_type": "llama4_text",
"hidden_size": 5120,
"total_num_hidden_layers": 48,
"total_num_attention_heads": 40,
"head_size": 128,
"vocab_size": 202048,
"total_num_kv_heads": 8,
"num_experts": 16,
"is_deepseek_mla": false,
"is_multimodal_model": true,
"dtype": "torch.bfloat16"
},
"nvidia/Llama-3_3-Nemotron-Super-49B-v1": {
"architectures": [
"DeciLMForCausalLM"
],
"model_type": "nemotron-nas",
"text_model_type": "nemotron-nas",
"hidden_size": 8192,
"total_num_hidden_layers": 80,
"total_num_attention_heads": 64,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"XiaomiMiMo/MiMo-7B-RL": {
"architectures": [
"MiMoForCausalLM"
],
"model_type": "mimo",
"text_model_type": "mimo",
"hidden_size": 4096,
"total_num_hidden_layers": 36,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 151680,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"meituan-longcat/LongCat-Flash-Chat": {
"architectures": [
"LongcatFlashForCausalLM"
],
"model_type": "longcat_flash",
"text_model_type": "longcat_flash",
"hidden_size": 6144,
"total_num_hidden_layers": 28,
"total_num_attention_heads": 64,
"head_size": 576,
"vocab_size": 131072,
"total_num_kv_heads": 64,
"num_experts": 512,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.float32"
}
}
{
"abhigoyal/vllm-medusa-llama-68m-random": {
"architectures": [
"MedusaModel"
],
"model_type": "medusa",
"text_model_type": "medusa",
"hidden_size": 768,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 0,
"head_size": "Error: integer division or modulo by zero",
"vocab_size": 32000,
"total_num_kv_heads": 0,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "torch.float32"
},
"luccafong/deepseek_mtp_draft_random": {
"architectures": [
"DeepSeekMTPModel"
],
"model_type": "deepseek_mtp",
"text_model_type": "deepseek_mtp",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "torch.bfloat16"
},
"eagle618/eagle-deepseek-v3-random": {
"architectures": [
"EagleDeepSeekMTPModel"
],
"model_type": "eagle",
"text_model_type": "deepseek_mtp",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 576,
"vocab_size": 129280,
"total_num_kv_heads": 32,
"num_experts": 72,
"is_deepseek_mla": true,
"is_multimodal_model": false,
"dtype": "bfloat16"
},
"yuhuili/EAGLE-LLaMA3-Instruct-8B": {
"architectures": [
"EagleLlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "float16"
},
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": {
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
"head_size": 128,
"vocab_size": 128256,
"total_num_kv_heads": 8,
"num_experts": 0,
"is_deepseek_mla": false,
"is_multimodal_model": false,
"dtype": "float16"
}
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ModelArchitectureConfig and its integration with ModelConfig."""
import json
from pathlib import Path
import pytest
from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig
from vllm.transformers_utils.model_arch_config_convertor import (
ModelArchConfigConvertorBase,
)
BASE_TRUST_REMOTE_CODE_MODELS = {
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
"XiaomiMiMo/MiMo-7B-RL",
# Excluded: Not available online right now
# "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
"meituan-longcat/LongCat-Flash-Chat",
}
BASE_MODELS_TO_TEST = [
"state-spaces/mamba-130m-hf",
"mistralai/Mamba-Codestral-7B-v0.1",
# Excluded: terratorch/torchgeo version mismatch in CPU CI environment
# (NonGeoDataset import error). Tested in model initialization tests.
# "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"Zyphra/Zamba2-7B-instruct",
# FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
# FIXME: databricks/dbrx-instruct has been deleted
# "databricks/dbrx-instruct",
"tiiuae/falcon-7b",
"tiiuae/falcon-40b",
"luccafong/deepseek_mtp_main_random",
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"tiny-random/qwen3-next-moe",
"zai-org/GLM-4.5",
"baidu/ERNIE-4.5-21B-A3B-PT",
# Models using base convertor
"lmsys/gpt-oss-20b-bf16",
"deepseek-ai/DeepSeek-V3.2-Exp",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
] + list(BASE_TRUST_REMOTE_CODE_MODELS)
# (target_model, draft_model, trust_remote_code)
SPECULATIVE_MODELS = [
("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False),
("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True),
("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True),
("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True),
("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True),
]
def _load_groundtruth(filename: str) -> dict:
"""Load groundtruth JSON from the test directory."""
groundtruth_path = Path(__file__).parent / filename
with open(groundtruth_path) as f:
return json.load(f)
def _assert_model_arch_config(
model_config, expected: dict, check_head_size: bool = True
):
"""Assert model_arch_config matches expected values."""
model_arch_config = model_config.model_arch_config
assert model_arch_config.architectures == expected["architectures"]
assert model_arch_config.model_type == expected["model_type"]
assert model_arch_config.text_model_type == expected["text_model_type"]
assert model_arch_config.hidden_size == expected["hidden_size"]
assert (
model_arch_config.total_num_hidden_layers == expected["total_num_hidden_layers"]
)
assert (
model_arch_config.total_num_attention_heads
== expected["total_num_attention_heads"]
)
assert model_arch_config.vocab_size == expected["vocab_size"]
assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"]
assert model_arch_config.num_experts == expected["num_experts"]
assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"]
torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
model_config.hf_config, model_config.model, revision=model_config.revision
)
assert str(torch_dtype) == expected["dtype"]
if check_head_size:
assert model_arch_config.head_size == expected["head_size"]
def _assert_model_config_methods(
model_config, expected: dict, check_head_size: bool = True
):
"""Assert model_config methods return expected values."""
assert model_config.architectures == expected["architectures"]
assert model_config.get_vocab_size() == expected["vocab_size"]
assert model_config.get_hidden_size() == expected["hidden_size"]
assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"]
assert model_config.get_num_experts() == expected["num_experts"]
assert (
model_config.get_total_num_hidden_layers()
== expected["total_num_hidden_layers"]
)
if check_head_size:
assert model_config.get_head_size() == expected["head_size"]
@pytest.mark.parametrize("model", BASE_MODELS_TO_TEST)
def test_base_model_arch_config(model: str):
"""Test model architecture config for base models."""
groundtruth = _load_groundtruth("base_model_arch_groundtruth.json")
expected = groundtruth[model]
model_config = ModelConfig(
model, trust_remote_code=model in BASE_TRUST_REMOTE_CODE_MODELS
)
_assert_model_arch_config(model_config, expected)
_assert_model_config_methods(model_config, expected)
@pytest.mark.parametrize(
"target_model,draft_model,trust_remote_code", SPECULATIVE_MODELS
)
def test_draft_model_arch_config(
target_model: str, draft_model: str, trust_remote_code: bool
):
"""Test model architecture config for draft/speculative models."""
groundtruth = _load_groundtruth("draft_model_arch_groundtruth.json")
expected = groundtruth[draft_model]
target_model_config = ModelConfig(target_model, trust_remote_code=trust_remote_code)
speculative_config = SpeculativeConfig(
model=draft_model,
num_speculative_tokens=1,
target_model_config=target_model_config,
target_parallel_config=ParallelConfig(),
)
model_config = speculative_config.draft_model_config
# For medusa models, head_size may cause division by zero before
# model_arch_config was introduced, so we conditionally check it
check_head_size = isinstance(expected["head_size"], int)
_assert_model_arch_config(model_config, expected, check_head_size=check_head_size)
_assert_model_config_methods(
model_config, expected, check_head_size=check_head_size
)
...@@ -471,12 +471,16 @@ def dummy_hf_overrides( ...@@ -471,12 +471,16 @@ def dummy_hf_overrides(
"num_kv_shared_layers": 1, "num_kv_shared_layers": 1,
} }
_hf_config = hf_config
class DummyConfig: class DummyConfig:
hf_config = _hf_config
hf_text_config = text_config hf_text_config = text_config
model_arch_config = ModelConfig.get_model_arch_config(DummyConfig)
# Only set MoE related config when the model has MoE layers. # Only set MoE related config when the model has MoE layers.
# Otherwise all models detected as MoE by _get_transformers_backend_cls. # Otherwise all models detected as MoE by _get_transformers_backend_cls.
if ModelConfig.get_num_experts(DummyConfig) > 0: if model_arch_config.num_experts > 0:
update_dict.update( update_dict.update(
{ {
"num_experts": num_experts, "num_experts": num_experts,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging import logging
import os import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
......
...@@ -16,6 +16,10 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config ...@@ -16,6 +16,10 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
from vllm.config.model import ModelConfig, get_hf_text_config from vllm.config.model import ModelConfig, get_hf_text_config
from vllm.transformers_utils.model_arch_config_convertor import (
MODEL_ARCH_CONFIG_CONVERTORS,
ModelArchConfigConvertorBase,
)
from vllm.v1.metrics.perf import ( from vllm.v1.metrics.perf import (
AttentionMetrics, AttentionMetrics,
BaseConfigParser, BaseConfigParser,
...@@ -33,6 +37,12 @@ class MockModelConfig: ...@@ -33,6 +37,12 @@ class MockModelConfig:
def __init__(self, hf_config, dtype): def __init__(self, hf_config, dtype):
self.hf_config = hf_config self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(hf_config) self.hf_text_config = get_hf_text_config(hf_config)
convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
self.hf_config.model_type, ModelArchConfigConvertorBase
)
self.model_arch_config = convertor_cls(
self.hf_config, self.hf_text_config
).convert()
self.dtype = dtype self.dtype = dtype
self.is_attention_free = False self.is_attention_free = False
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelArchitectureConfig:
"""
Configuration for model architecture that required by vLLM runtime
"""
architectures: list[str] | None
"""List of model architecture class names (e.g., ['LlamaForCausalLM']).
It can be None upon calling `vllm_config.with_hf_config(config.text_config)`"""
model_type: str
"""Model type identifier (e.g., 'llama', 'gpt_oss')."""
text_model_type: str | None
"""Text model type identifier (e.g., 'llama4_text')."""
hidden_size: int
"""Hidden size of the model."""
total_num_hidden_layers: int
"""Number of hidden layers in the model."""
total_num_attention_heads: int
"""Number of attention heads in the model."""
head_size: int
"""Head dimension of the model."""
vocab_size: int
"""Vocabulary size of the model."""
total_num_kv_heads: int
"""Number of key value heads in the model."""
num_experts: int
"""Number of experts in the model."""
quantization_config: dict[str, Any] | None
"""Quantization configuration dictionary containing quantization parameters."""
is_deepseek_mla: bool
"""Whether the model is a DeepSeek MLA model."""
derived_max_model_len_and_key: tuple[float, str | None]
"""Derived maximum model length and key from the hf config."""
...@@ -401,6 +401,9 @@ class SpeculativeConfig: ...@@ -401,6 +401,9 @@ class SpeculativeConfig:
model_type="eagle", model_type="eagle",
) )
self.draft_model_config.hf_config = eagle_config self.draft_model_config.hf_config = eagle_config
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
if self.num_speculative_tokens is not None and hasattr( if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens" self.draft_model_config.hf_config, "num_lookahead_tokens"
......
...@@ -444,6 +444,7 @@ class VllmConfig: ...@@ -444,6 +444,7 @@ class VllmConfig:
model_config = copy.deepcopy(self.model_config) model_config = copy.deepcopy(self.model_config)
model_config.hf_config = hf_config model_config.hf_config = hf_config
model_config.model_arch_config = model_config.get_model_arch_config()
return replace(self, model_config=model_config) return replace(self, model_config=model_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import final
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.model_arch import (
ModelArchitectureConfig,
)
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
try_get_safetensors_metadata,
)
from vllm.utils.torch_utils import common_broadcastable_dtype
logger = init_logger(__name__)
class ModelArchConfigConvertorBase:
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
self.hf_config = hf_config
self.hf_text_config = hf_text_config
def get_architectures(self) -> list[str]:
return getattr(self.hf_config, "architectures", [])
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_hidden_layers", 0)
def get_total_num_attention_heads(self) -> int:
return getattr(self.hf_text_config, "num_attention_heads", 0)
def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)
def get_hidden_size(self) -> int:
return getattr(self.hf_text_config, "hidden_size", 0)
def get_head_size(self) -> int:
if self.is_deepseek_mla():
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
if not envs.VLLM_MLA_DISABLE:
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
else:
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
# NOTE: Some configs may set head_dim=None in the config
if getattr(self.hf_text_config, "head_dim", None) is not None:
return self.hf_text_config.head_dim
# NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head
# FIXME(woosuk): This may not be true for all models.
return (
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
)
def get_total_num_kv_heads(self) -> int:
attributes = [
# For Falcon:
"n_head_kv",
"num_kv_heads",
# For LLaMA-2:
"num_key_value_heads",
# For ChatGLM:
"multi_query_group_num",
]
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_experts(self) -> int:
"""Returns the number of experts in the model."""
num_expert_names = [
"num_experts", # Jamba
"moe_num_experts", # Dbrx
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one.
return num_experts[0]
# Coerce to 0 if explicitly set to None
return num_experts or 0
@final
@classmethod
def get_torch_dtype(
cls, hf_config: PretrainedConfig, model_id: str, revision: str | None
):
# NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None.
config_dtype = getattr(hf_config, "dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define dtype
if config_dtype is None:
config_dtype = getattr(hf_config.get_text_config(), "dtype", None)
if config_dtype is None and hasattr(hf_config, "vision_config"):
config_dtype = getattr(hf_config.vision_config, "dtype", None)
if config_dtype is None and hasattr(hf_config, "encoder_config"):
config_dtype = getattr(hf_config.encoder_config, "dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
for file_mt in files_mt.values()
for dtype_str in file_mt.parameter_count
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
}
if param_dtypes:
return common_broadcastable_dtype(param_dtypes)
if config_dtype is None:
config_dtype = torch.float32
return config_dtype
def _normalize_quantization_config(self, config: PretrainedConfig):
quant_cfg = getattr(config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(config, "compression_config", None)
else:
# Set quant_method for ModelOpt models.
producer_name = quant_cfg.get("producer", {}).get("name")
if producer_name == "modelopt":
quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
if quant_algo is not None:
quant_algo_upper = str(quant_algo).upper()
if quant_algo_upper in {
"FP8",
"FP8_PER_CHANNEL_PER_TOKEN",
"FP8_PB_WO",
}:
quant_cfg["quant_method"] = "modelopt"
elif quant_algo_upper == "NVFP4":
quant_cfg["quant_method"] = "modelopt_fp4"
else:
raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
if quant_cfg is not None:
# Use the community standard 'quant_method'
quant_method = quant_cfg.get("quant_method", "").lower()
# Normalize library names
quant_method = quant_method.replace(
"compressed_tensors", "compressed-tensors"
)
quant_cfg["quant_method"] = quant_method
return quant_cfg
def get_quantization_config(self):
quant_cfg = self._normalize_quantization_config(self.hf_config)
if quant_cfg is None and (
text_config := getattr(self.hf_config, "text_config", None)
):
# Check the text config as well for multi-modal models.
quant_cfg = self._normalize_quantization_config(text_config)
return quant_cfg
def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in (
"deepseek_v2",
"deepseek_v3",
"deepseek_v32",
"deepseek_mtp",
"kimi_k2",
"kimi_linear",
"longcat_flash",
"pangu_ultra_moe",
"pangu_ultra_moe_mtp",
):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == "eagle":
# if the model is an EAGLE module, check for the
# underlying architecture
return (
self.hf_text_config.model.model_type
in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
and self.hf_text_config.kv_lora_rank is not None
)
return False
def derive_max_model_len_and_key(self) -> tuple[float, str | None]:
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# ChatGLM2
"seq_length",
# Command-R
"model_max_length",
# Whisper
"max_target_positions",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
# Choose the smallest "max_length" from the possible keys
max_len_key = None
for key in possible_keys:
max_len = getattr(self.hf_text_config, key, None)
if max_len is not None:
if max_len < derived_max_model_len:
max_len_key = key
derived_max_model_len = min(derived_max_model_len, max_len)
# For Command-R / Cohere, Cohere2 / Aya Vision models
if tmp_max_len := getattr(self.hf_text_config, "model_max_length", None):
max_len_key = "model_max_length"
derived_max_model_len = tmp_max_len
return derived_max_model_len, max_len_key
def convert(self) -> ModelArchitectureConfig:
model_arch_config = ModelArchitectureConfig(
architectures=self.get_architectures(),
model_type=self.hf_config.model_type,
text_model_type=getattr(self.hf_text_config, "model_type", None),
hidden_size=self.get_hidden_size(),
total_num_hidden_layers=self.get_num_hidden_layers(),
total_num_attention_heads=self.get_total_num_attention_heads(),
head_size=self.get_head_size(),
vocab_size=self.get_vocab_size(),
total_num_kv_heads=self.get_total_num_kv_heads(),
num_experts=self.get_num_experts(),
quantization_config=self.get_quantization_config(),
is_deepseek_mla=self.is_deepseek_mla(),
derived_max_model_len_and_key=self.derive_max_model_len_and_key(),
)
return model_arch_config
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
def get_total_num_kv_heads(self) -> int:
return 0
class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
def get_total_num_kv_heads(self) -> int:
return 0
class MedusaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
def get_total_num_kv_heads(self) -> int:
return 0
class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return getattr(self.hf_text_config, "attention_head_dim", 0)
class FalconModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_kv_heads(self) -> int:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
new_decoder_arch_falcon = getattr(
self.hf_text_config, "new_decoder_architecture", False
)
if not new_decoder_arch_falcon and getattr(
self.hf_text_config, "multi_query", False
):
# Multi-query attention, only one KV head.
return 1
# Use the base implementation which checks n_head_kv, num_kv_heads, etc.
return super().get_total_num_kv_heads()
class MPTModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_kv_heads(self) -> int:
if "kv_n_heads" in self.hf_text_config.attn_config:
return self.hf_text_config.attn_config["kv_n_heads"]
return self.hf_text_config.num_attention_heads
class DbrxModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_kv_heads(self) -> int:
return getattr(
self.hf_text_config.attn_config,
"kv_n_heads",
self.hf_text_config.num_attention_heads,
)
class NemotronNasModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_kv_heads(self) -> int:
for block in self.hf_text_config.block_configs:
if not block.attention.no_op:
return (
self.hf_text_config.num_attention_heads
// block.attention.n_heads_in_group
)
raise RuntimeError(
"Could not determine the number of key-value attention heads "
"from model configuration. "
f"Architecture: {self.get_architectures()}. "
"This usually indicates an unsupported model architecture or "
"missing configuration. "
"Please check if your model is supported at: "
"https://docs.vllm.ai/en/latest/models/supported_models.html"
)
class DeepSeekMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class MimoMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class GLM4MoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class ErnieMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class PanguUltraMoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"mamba": MambaModelArchConfigConvertor,
"falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,
"medusa": MedusaModelArchConfigConvertor,
"zamba2": Zamba2ModelArchConfigConvertor,
"mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor,
"falcon": FalconModelArchConfigConvertor,
"RefinedWeb": FalconModelArchConfigConvertor,
"RefinedWebModel": FalconModelArchConfigConvertor,
"nemotron-nas": NemotronNasModelArchConfigConvertor,
"deepseek_mtp": DeepSeekMTPModelArchConfigConvertor,
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment