Unverified Commit 8d825b87 authored by Tihomir Elek's avatar Tihomir Elek Committed by GitHub
Browse files

[Bug] Fix TypeError when hf_config.architectures is None during model loading (#38849)


Signed-off-by: default avatarTihomir Elek <tiho.elek@gmail.com>
parent 1b19bd75
......@@ -4,12 +4,14 @@
import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from types import SimpleNamespace
from unittest.mock import patch
import pydantic
import pytest
from pydantic import ValidationError
import vllm.config.vllm as vllm_config_module
from vllm.compilation.backends import VllmBackend
from vllm.config import (
CompilationConfig,
......@@ -45,6 +47,81 @@ def test_compile_config_repr_succeeds():
assert "inductor_passes" in val
@pytest.mark.skip_global_cleanup
def test_with_hf_config_populates_missing_architectures_from_causal_lm_mapping(
monkeypatch,
):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(model_type="mistral", architectures=None)
updated = VllmConfig.with_hf_config(cfg, hf_config)
assert updated.model_config.hf_config.architectures == ["MistralForCausalLM"]
assert hf_config.architectures is None
@pytest.mark.skip_global_cleanup
def test_with_hf_config_preserves_explicit_architectures_override(monkeypatch):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(model_type="mistral", architectures=None)
updated = VllmConfig.with_hf_config(
cfg,
hf_config,
architectures=["Ministral3ForCausalLM"],
)
assert updated.model_config.hf_config.architectures == ["Ministral3ForCausalLM"]
@pytest.mark.skip_global_cleanup
def test_with_hf_config_leaves_unknown_model_type_without_architectures(
monkeypatch,
):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(
model_type="not_a_real_model",
architectures=None,
)
updated = VllmConfig.with_hf_config(cfg, hf_config)
assert updated.model_config.hf_config.architectures is None
def test_async_scheduling_with_pipeline_parallelism_is_allowed():
cfg = VllmConfig(
scheduler_config=SchedulerConfig(
......
......@@ -559,6 +559,16 @@ class VllmConfig:
if architectures is not None:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = architectures
elif hf_config.architectures is None:
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
if hf_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = [
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[hf_config.model_type]
]
model_config = copy.deepcopy(self.model_config)
......
......@@ -175,7 +175,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
architectures = getattr(model_config.hf_config, "architectures", [])
architectures = getattr(model_config.hf_config, "architectures", None) or []
model_cls, arch = model_config.registry.resolve_model_cls(
architectures,
......@@ -215,7 +215,7 @@ def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
model_config.runner_type,
model_config.trust_remote_code,
model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", [])),
tuple(getattr(model_config.hf_config, "architectures", None) or []),
)
)
if key in _MODEL_ARCH_BY_HASH:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment