# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import pytest

import vllm.envs as envs
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs

# Model name passed as ``model=`` to ``AsyncEngineArgs`` / ``LLM`` by the
# tests below (and reused as the ``speculative_config`` model in
# ``test_unsupported_configs``).
MODEL = "meta-llama/Llama-3.2-1B-Instruct"


def test_reject_bad_config(monkeypatch):
    """Disable the V1 engine via ``VLLM_USE_V1=0`` inside a scoped patch.

    NOTE(review): the body only sets the environment variable within a
    ``monkeypatch.context()`` (undone on exit) and asserts nothing, so this
    test can never fail. Confirm whether a rejection check was intended.
    """
    with monkeypatch.context() as env_patch:
        env_patch.setenv("VLLM_USE_V1", "0")


def test_unsupported_configs(monkeypatch):
    """Building a config with a ``speculative_config`` must fail under V1.

    With ``VLLM_USE_V1=1`` forced, ``create_engine_config()`` is expected to
    raise ``NotImplementedError`` when a ``speculative_config`` whose
    ``"model"`` entry is ``MODEL`` itself is supplied.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                speculative_config={
                    "model": MODEL,
                },
            ).create_engine_config()


def test_enable_by_default_fallback(monkeypatch):
    """With ``VLLM_USE_V1`` unset, a supported config should select V1.

    The test reads the engine choice back through ``envs.VLLM_USE_V1``
    after ``create_engine_config()`` has run.
    """
    with monkeypatch.context() as m:
        # Start from an unset variable so the engine picks its default.
        # ``raising=False`` also clears a variable set to the empty string,
        # which a truthiness check on ``os.getenv`` would have skipped.
        m.delenv("VLLM_USE_V1", raising=False)

        # Should default to V1 for supported config.
        _ = AsyncEngineArgs(
            model=MODEL,
            enforce_eager=True,
        ).create_engine_config()
        assert envs.VLLM_USE_V1
        # create_engine_config() appears to export its choice into the
        # environment (this delenv would raise KeyError otherwise).
        # NOTE(review): monkeypatch undo restores the value seen at this
        # delenv, so it may persist after the context exits when the
        # variable was originally unset -- verify.
        m.delenv("VLLM_USE_V1")


def test_v1_llm_by_default(monkeypatch):
    """An ``LLM`` built with ``VLLM_USE_V1`` unset should run on V1.

    The test treats the presence of an ``engine_core`` attribute on
    ``llm.llm_engine`` as the marker of the V1 engine.
    """
    with monkeypatch.context() as m:
        # Start from an unset variable (including one set to ""), so the
        # engine picks its default.
        m.delenv("VLLM_USE_V1", raising=False)

        # Should default to V1 for supported config.
        llm = LLM(MODEL, enforce_eager=True, enable_lora=True)
        print(llm.generate("Hello my name is"))
        assert hasattr(llm.llm_engine, "engine_core")
        # The engine appears to export its choice into the environment
        # (this delenv would raise KeyError otherwise); clear it again.
        m.delenv("VLLM_USE_V1")


def test_v1_attn_backend(monkeypatch):
    """Check how ``VLLM_ATTENTION_BACKEND`` interacts with engine selection.

    * ``XFORMERS`` with ``VLLM_USE_V1`` unset falls back to V0.
    * ``XFORMERS`` with ``VLLM_USE_V1=1`` raises ``NotImplementedError``.
    * ``FLASHMLA`` with ``VLLM_USE_V1`` unset selects V1.
    """
    with monkeypatch.context() as m:
        # Start from an unset variable (including one set to "").
        m.delenv("VLLM_USE_V1", raising=False)
        m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

        # Fall back to V0.
        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
        assert not envs.VLLM_USE_V1
        # Presumably clears the choice exported by create_engine_config()
        # so the next scenario is decided afresh.
        m.delenv("VLLM_USE_V1")

        # Reject if V1.
        m.setenv("VLLM_USE_V1", "1")
        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(model=MODEL).create_engine_config()
        m.delenv("VLLM_USE_V1")

        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHMLA")
        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
        assert envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")