# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
from transformers import PretrainedConfig

from vllm import LLM
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore

from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
    """Check that each registered example architecture can be constructed.

    The model is built through ``LLM(...)`` with ``load_format="dummy"`` and
    a config trimmed to a single layer (and two experts for MoE models), so
    only construction is exercised, not real checkpoint weights. The
    engines' ``_initialize_kv_caches`` methods (V0 and V1) are patched with
    stubs so that the engine does not call ``model.forward()``.

    Args:
        model_arch: Architecture name taken from ``HF_EXAMPLE_MODELS``.
        monkeypatch: pytest fixture; its ``context()`` scopes the
            ``VLLM_USE_V1`` environment override to this test.
    """
    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # FIXME: Possible memory leak in the previous tests?
    if model_arch == "GraniteSpeechForConditionalGeneration":
        pytest.skip("Avoid OOM")

    # Avoid OOM and reduce initialization time by only using 1 layer
    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Apply the model's registered overrides, then shrink the config.

        Mutates ``hf_config`` in place and returns it.
        """
        hf_config.update(model_info.hf_overrides)

        text_config = hf_config.get_text_config()
        text_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
            # MoE models: keep only two experts, with top-k equal to the
            # expert count.
            "num_experts": 2,
            "num_experts_per_tok": 2,
            "num_local_experts": 2,
        })

        # Sub-model configs get the same single-layer treatment, applied in
        # this order: the vision tower, then the encoder (e.g. the audio
        # encoder of ibm-granite/granite-speech-3.3-2b).
        for sub_config_name in ("vision_config", "encoder_config"):
            if hasattr(hf_config, sub_config_name):
                getattr(hf_config, sub_config_name).update({
                    "num_layers": 1,
                    "num_hidden_layers": 1,
                })

        return hf_config

    # Avoid calling model.forward()
    def _initialize_kv_caches_v0(self) -> None:
        """Stand-in for the V0 ``LLMEngine._initialize_kv_caches``.

        Records zero GPU and zero CPU KV-cache blocks without running the
        model.
        """
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    def _initialize_kv_caches_v1(self, vllm_config):
        """Stand-in for the V1 ``EngineCore._initialize_kv_caches``.

        Builds the scheduler KV-cache config from the first entry of the
        executor's KV-cache specs (presumably one entry per worker; verify)
        and a fixed ``10 * GiB_bytes`` memory figure, instead of running the
        model.
        """
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_config(
            vllm_config,
            kv_cache_specs[0],
            10 * GiB_bytes,
        )

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                       _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
                       _initialize_kv_caches_v1), monkeypatch.context() as m):
        # Architectures flagged V0-only must opt out of the V1 engine.
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")
        LLM(
            model_info.default,
            tokenizer=model_info.tokenizer,
            tokenizer_mode=model_info.tokenizer_mode,
            speculative_config={
                "model": model_info.speculative_model,
                "num_speculative_tokens": 1,
            } if model_info.speculative_model else None,
            trust_remote_code=model_info.trust_remote_code,
            max_model_len=model_info.max_model_len,
            load_format="dummy",
            hf_overrides=hf_overrides,
        )