# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import pytest

import vllm.envs as envs
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.platforms import current_platform
from ..utils import models_path_prefix

# Models that the V1 engine is expected to reject when VLLM_USE_V1=1 is
# forced (see test_reject_unsupported_models). The comment above each entry
# names the model category. Order matters: test_enable_by_default_fallback
# uses the first entry as its "unsupported model" example.
UNSUPPORTED_MODELS_V1 = [
    # transcription
    os.path.join(models_path_prefix, "openai/whisper-large-v3"),
    # encoder decoder
    os.path.join(models_path_prefix, "facebook/bart-large-cnn"),
    # mamba1
    os.path.join(models_path_prefix, "state-spaces/mamba-130m-hf"),
    # embedding
    os.path.join(models_path_prefix, "BAAI/bge-m3"),
]

# Model used by the tests below as a configuration V1 is expected to accept
# (see test_enable_by_default_fallback and test_v1_llm_by_default).
MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")


@pytest.mark.parametrize("model", UNSUPPORTED_MODELS_V1)
def test_reject_unsupported_models(monkeypatch, model):
    """Forcing V1 for a model in ``UNSUPPORTED_MODELS_V1`` must fail.

    With ``VLLM_USE_V1=1`` set, ``create_engine_config`` is expected to
    raise ``NotImplementedError`` instead of silently falling back.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")
        engine_args = AsyncEngineArgs(model=model)

        with pytest.raises(NotImplementedError):
            engine_args.create_engine_config()

        # Explicitly drop the forced flag before leaving the context.
        patcher.delenv("VLLM_USE_V1")


def test_reject_bad_config(monkeypatch):
    """Set ``VLLM_USE_V1=0`` inside a monkeypatch context.

    NOTE(review): this test currently performs no engine call and no
    assertion, so it can never fail. TODO: decide which "bad config" it is
    meant to reject and add the corresponding ``pytest.raises`` check.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "0")


def test_unsupported_configs(monkeypatch):
    """Forcing V1 together with an option V1 does not implement must fail.

    With ``VLLM_USE_V1=1`` set, each ``AsyncEngineArgs`` below enables one
    engine option on top of ``MODEL``. ``create_engine_config`` is expected
    to raise ``NotImplementedError`` for every one of them.
    """

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                kv_cache_dtype="fp8",
            ).create_engine_config()

        # Speculative decoding, using MODEL itself as the draft model.
        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                speculative_config={
                    "model": MODEL,
                },
            ).create_engine_config()

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                guided_decoding_backend="lm-format-enforcer",
                guided_decoding_disable_fallback=True,
            ).create_engine_config()

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                preemption_mode="swap",
            ).create_engine_config()

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                disable_async_output_proc=True,
            ).create_engine_config()

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                num_scheduler_steps=5,
            ).create_engine_config()

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                scheduler_delay_factor=1.2,
            ).create_engine_config()


def test_enable_by_default_fallback(monkeypatch):
    """With ``VLLM_USE_V1`` unset, the oracle picks V1 or falls back to V0.

    A supported config (``MODEL``) is expected to select V1, while the
    first entry of ``UNSUPPORTED_MODELS_V1`` is expected to fall back to V0.
    After each ``create_engine_config`` call the variable is deleted again.
    Those ``delenv`` calls use the default ``raising=True``, so they also
    fail the test if the config step did not write ``VLLM_USE_V1``.
    """
    with monkeypatch.context() as m:
        # Clear any inherited value. raising=False also handles the
        # empty-string case, which the previous ``if os.getenv(...)``
        # truthiness check left in place.
        m.delenv("VLLM_USE_V1", raising=False)

        # Should default to V1 for supported config.
        _ = AsyncEngineArgs(
            model=MODEL,
            enforce_eager=True,
        ).create_engine_config()
        assert envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")

        # Should fall back to V0 for an unsupported model.
        _ = AsyncEngineArgs(
            model=UNSUPPORTED_MODELS_V1[0]).create_engine_config()
        assert not envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")


def test_v1_llm_by_default(monkeypatch):
    """``LLM`` should build the V1 engine by default for a supported config.

    V1 is detected by the presence of ``llm_engine.engine_core``. The final
    ``delenv`` (default ``raising=True``) also requires that constructing
    the engine wrote ``VLLM_USE_V1``.
    """
    with monkeypatch.context() as m:
        # Clear any inherited value, including an empty string.
        m.delenv("VLLM_USE_V1", raising=False)

        # Should default to V1 for supported config.
        model = LLM(MODEL, enforce_eager=True, enable_lora=True)
        print(model.generate("Hello my name is"))
        assert hasattr(model.llm_engine, "engine_core")
        m.delenv("VLLM_USE_V1")


def test_v1_attn_backend(monkeypatch):
    """Check how ``VLLM_ATTENTION_BACKEND`` steers the V1 oracle.

    Off ROCm, the test selects ``XFORMERS`` and expects two things:
    - with ``VLLM_USE_V1`` unset, the config falls back to V0;
    - with ``VLLM_USE_V1=1`` forced, ``create_engine_config`` raises
      ``NotImplementedError``.

    On every platform, selecting ``FLASHMLA`` is expected to keep V1.
    """
    with monkeypatch.context() as m:
        # Start with VLLM_USE_V1 unset so the oracle makes the decision.
        # This is platform-independent, so it is not inside the ROCm guard.
        m.delenv("VLLM_USE_V1", raising=False)

        # XFORMERS is deliberately not selected on ROCm. The two checks
        # below rely on XFORMERS being the active backend, so they must be
        # guarded too; otherwise they would run on ROCm without it.
        if not current_platform.is_rocm():
            m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

            # Fall back to V0.
            _ = AsyncEngineArgs(model=MODEL).create_engine_config()
            assert not envs.VLLM_USE_V1
            m.delenv("VLLM_USE_V1")

            # Reject if V1.
            m.setenv("VLLM_USE_V1", "1")
            with pytest.raises(NotImplementedError):
                AsyncEngineArgs(model=MODEL).create_engine_config()
            m.delenv("VLLM_USE_V1")

        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHMLA")
        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
        assert envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")


def test_reject_using_constructor_directly(monkeypatch):
    """Calling the V0 ``AsyncLLMEngine`` constructor on a V1 config must fail.

    The config is built with ``VLLM_USE_V1`` unset; per the comment below,
    this sets ``VLLM_USE_V1=1``. Passing that config straight to
    ``AsyncLLMEngine(...)`` is then expected to raise ``ValueError``.
    """
    with monkeypatch.context() as m:
        # Clear any inherited value, including an empty string.
        m.delenv("VLLM_USE_V1", raising=False)

        # Sets VLLM_USE_V1=1.
        vllm_config = AsyncEngineArgs(model=MODEL).create_engine_config()

        # This uses the V0 constructor directly.
        with pytest.raises(ValueError):
            AsyncLLMEngine(vllm_config,
                           AsyncLLMEngine._get_executor_cls(vllm_config),
                           log_stats=True)

        m.delenv("VLLM_USE_V1")