Unverified Commit 8d825b87 authored by Tihomir Elek's avatar Tihomir Elek Committed by GitHub
Browse files

[Bug] Fix TypeError when hf_config.architectures is None during model loading (#38849)


Signed-off-by: default avatarTihomir Elek <tiho.elek@gmail.com>
parent 1b19bd75
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
import logging import logging
import os import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import pydantic import pydantic
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
import vllm.config.vllm as vllm_config_module
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
...@@ -45,6 +47,81 @@ def test_compile_config_repr_succeeds(): ...@@ -45,6 +47,81 @@ def test_compile_config_repr_succeeds():
assert "inductor_passes" in val assert "inductor_passes" in val
@pytest.mark.skip_global_cleanup
def test_with_hf_config_populates_missing_architectures_from_causal_lm_mapping(
monkeypatch,
):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(model_type="mistral", architectures=None)
updated = VllmConfig.with_hf_config(cfg, hf_config)
assert updated.model_config.hf_config.architectures == ["MistralForCausalLM"]
assert hf_config.architectures is None
@pytest.mark.skip_global_cleanup
def test_with_hf_config_preserves_explicit_architectures_override(monkeypatch):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(model_type="mistral", architectures=None)
updated = VllmConfig.with_hf_config(
cfg,
hf_config,
architectures=["Ministral3ForCausalLM"],
)
assert updated.model_config.hf_config.architectures == ["Ministral3ForCausalLM"]
@pytest.mark.skip_global_cleanup
def test_with_hf_config_leaves_unknown_model_type_without_architectures(
monkeypatch,
):
monkeypatch.setattr(
vllm_config_module,
"replace",
lambda self, **kwargs: SimpleNamespace(**kwargs),
)
cfg = SimpleNamespace(
model_config=SimpleNamespace(
is_multimodal_model=False,
hf_config=SimpleNamespace(),
get_model_arch_config=lambda: "arch-config",
)
)
hf_config = SimpleNamespace(
model_type="not_a_real_model",
architectures=None,
)
updated = VllmConfig.with_hf_config(cfg, hf_config)
assert updated.model_config.hf_config.architectures is None
def test_async_scheduling_with_pipeline_parallelism_is_allowed(): def test_async_scheduling_with_pipeline_parallelism_is_allowed():
cfg = VllmConfig( cfg = VllmConfig(
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
......
...@@ -559,6 +559,16 @@ class VllmConfig: ...@@ -559,6 +559,16 @@ class VllmConfig:
if architectures is not None: if architectures is not None:
hf_config = copy.deepcopy(hf_config) hf_config = copy.deepcopy(hf_config)
hf_config.architectures = architectures hf_config.architectures = architectures
elif hf_config.architectures is None:
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
if hf_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = [
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[hf_config.model_type]
]
model_config = copy.deepcopy(self.model_config) model_config = copy.deepcopy(self.model_config)
......
...@@ -175,7 +175,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() ...@@ -175,7 +175,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", None) or []
model_cls, arch = model_config.registry.resolve_model_cls( model_cls, arch = model_config.registry.resolve_model_cls(
architectures, architectures,
...@@ -215,7 +215,7 @@ def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -215,7 +215,7 @@ def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
model_config.runner_type, model_config.runner_type,
model_config.trust_remote_code, model_config.trust_remote_code,
model_config.model_impl, model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", [])), tuple(getattr(model_config.hf_config, "architectures", None) or []),
) )
) )
if key in _MODEL_ARCH_BY_HASH: if key in _MODEL_ARCH_BY_HASH:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment