# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
from transformers import PretrainedConfig

from vllm import LLM
10
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
11
12
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
13
from vllm.v1.engine.core import EngineCore as V1EngineCore
14

15
from ..utils import create_new_process_for_each_test
16
from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels
17
18


19
@create_new_process_for_each_test()
def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
                   EXAMPLE_MODELS: HfExampleModels):
    """Check that the example model for ``model_arch`` can be constructed.

    The model is built through ``LLM(...)`` with dummy weights
    (``load_format="dummy"``), a shrunken HF config (see ``hf_overrides``
    below) and patched KV-cache initialization, so the test only exercises
    engine/model construction.

    The reason for using create_new_process_for_each_test is to avoid
    the WARNING:
        "We must use the 'spawn' multiprocessing start method. Overriding
        VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
    The spawn process causes the _initialize_kv_caches_v1 function below to
    become ineffective.

    Args:
        model_arch: Architecture name used to look up the example model in
            ``EXAMPLE_MODELS``.
        monkeypatch: pytest fixture; used to set environment variables for
            the duration of the ``LLM`` construction only.
        EXAMPLE_MODELS: Registry that provides the model info (checkpoint
            name, tokenizer, revision, overrides, ...) for ``model_arch``.
    """

    model_info = EXAMPLE_MODELS.get_hf_info(model_arch)
    # Both checks skip the test (rather than fail) when unmet.
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # FIXME: Possible memory leak in the previous tests?
    if model_arch in ("Glm4vForConditionalGeneration",
                      "GraniteSpeechForConditionalGeneration",
                      "KimiVLForConditionalGeneration"):
        pytest.skip("Avoid OOM")

    if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
        # Imported lazily so the registration only happens for these archs.
        from vllm.model_executor.models.llama4 import Llama4ForCausalLM
        from vllm.model_executor.models.registry import ModelRegistry
        ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)

    # Avoid OOM and reduce initialization time by only using 1 layer
    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Apply the model's own overrides, then shrink the config."""
        hf_config.update(model_info.hf_overrides)

        text_config = hf_config.get_text_config()

        # Ensure at least 2 experts per group
        # Since `grouped_topk` assumes top-2
        n_group = getattr(text_config, 'n_group', None)
        num_experts = n_group * 2 if n_group is not None else 2

        # we use three layers for Gemma-3n to check
        # both normal layer and kv_shared_layer
        num_hidden_layers = (3 if model_arch
                             == "Gemma3nForConditionalGeneration" else 1)

        text_config.update({
            "num_layers": 1,
            "num_hidden_layers": num_hidden_layers,
            "num_experts": num_experts,
            "num_experts_per_tok": 2,
            "num_local_experts": num_experts,
            # Otherwise there will not be any expert layers
            "first_k_dense_replace": 0,
            # To avoid OOM on DeepSeek-V3
            "n_routed_experts": num_experts,
            # For Gemma-3n
            "num_kv_shared_layers": 1,
        })

        if hasattr(hf_config, "vision_config"):
            hf_config.vision_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
            })

        # e.g.: ibm-granite/granite-speech-3.3-2b
        if hasattr(hf_config, "encoder_config"):
            hf_config.encoder_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
            })

        return hf_config

    # Avoid calling model.forward()
    def _initialize_kv_caches_v0(self) -> None:
        """Replacement for ``V0LLMEngine._initialize_kv_caches``."""
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    def _initialize_kv_caches_v1(self, vllm_config):
        """Replacement for ``V1EngineCore._initialize_kv_caches``."""
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_config(
            vllm_config,
            kv_cache_specs[0],
            10 * GiB_bytes,
        )

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                       _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
                       _initialize_kv_caches_v1), monkeypatch.context() as m):
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")
        if model_arch == "Phi4FlashForCausalLM":
            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
            m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
        LLM(
            model_info.default,
            tokenizer=model_info.tokenizer,
            tokenizer_mode=model_info.tokenizer_mode,
            revision=model_info.revision,
            speculative_config={
                "model": model_info.speculative_model,
                "num_speculative_tokens": 1,
            } if model_info.speculative_model else None,
            trust_remote_code=model_info.trust_remote_code,
            max_model_len=model_info.max_model_len,
            # these tests seem to produce leftover memory
            gpu_memory_utilization=0.80,
            load_format="dummy",
            hf_overrides=hf_overrides,
        )
131
132
133
134
135
136
137
138
139
140
141
142


@pytest.mark.parametrize(
    "model_arch",
    HF_EXAMPLE_MODELS.get_supported_archs(),
)
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for every architecture that
    ``HF_EXAMPLE_MODELS.get_supported_archs()`` returns."""
    example_models = HF_EXAMPLE_MODELS
    can_initialize(model_arch, monkeypatch, example_models)


@pytest.mark.parametrize(
    "model_arch",
    AUTO_EXAMPLE_MODELS.get_supported_archs(),
)
def test_implicit_converted_models(model_arch: str,
                                   monkeypatch: pytest.MonkeyPatch):
    """Run ``can_initialize`` for every architecture that
    ``AUTO_EXAMPLE_MODELS.get_supported_archs()`` returns."""
    example_models = AUTO_EXAMPLE_MODELS
    can_initialize(model_arch, monkeypatch, example_models)