# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest
from transformers import PretrainedConfig

from vllm import LLM
9
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
10
11
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
12
from vllm.v1.engine.core import EngineCore as V1EngineCore
13
14
15
16
17

from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
    """Check that ``LLM(...)`` can be constructed for ``model_arch``.

    The HF config is shrunk to a single layer (and 2 experts) through
    ``hf_overrides``, weights use ``load_format="dummy"``, and the V0/V1
    ``_initialize_kv_caches`` hooks are patched out so construction does
    not run ``model.forward()`` for KV-cache profiling.

    The test is skipped when the example model is not available online or
    when the installed ``transformers`` version is unsupported.
    """
    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # Avoid OOM and reduce initialization time by only using 1 layer
    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Apply per-model overrides, then shrink the config to 1 layer."""
        # Model-specific overrides go first so the size reductions below
        # take precedence over them.
        hf_config.update(model_info.hf_overrides)

        text_config = hf_config.get_text_config()
        text_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
            "num_experts": 2,
            "num_experts_per_tok": 2,
            "num_local_experts": 2,
        })

        # The attribute may exist but be None on some configs; hasattr()
        # alone would then crash on `.update`.
        vision_config = getattr(hf_config, "vision_config", None)
        if vision_config is not None:
            vision_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
            })

        return hf_config

    # Avoid calling model.forward()
    def _initialize_kv_caches_v0(self) -> None:
        """Stand-in for V0 ``LLMEngine._initialize_kv_caches``: no blocks."""
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    def _initialize_kv_caches_v1(self, vllm_config):
        """Stand-in for V1 ``EngineCore._initialize_kv_caches``.

        Builds a scheduler KV-cache config from the executor's specs without
        profiling memory, and reports 1 GPU block and 0 CPU blocks.
        """
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_config(
            vllm_config,
            # NOTE(review): presumably the first worker's spec; confirm.
            kv_cache_specs[0],
            # NOTE(review): assumed to be the available-memory budget used
            # for sizing; confirm against get_kv_cache_config's signature.
            10 * GiB_bytes,
        )

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

    speculative_config = {
        "model": model_info.speculative_model,
        "num_speculative_tokens": 1,
    } if model_info.speculative_model else None

    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                       _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
                       _initialize_kv_caches_v1), monkeypatch.context() as m):
        # The env override is scoped to this context and undone on exit.
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")
        LLM(
            model_info.default,
            tokenizer=model_info.tokenizer,
            tokenizer_mode=model_info.tokenizer_mode,
            speculative_config=speculative_config,
            trust_remote_code=model_info.trust_remote_code,
            max_model_len=model_info.max_model_len,
            load_format="dummy",
            hf_overrides=hf_overrides,
        )