# test_engine_args.py 2.05 KB
# Newer Older
1
2
3
4
5
6
import pytest

from vllm import envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
7
from vllm.utils import FlexibleArgumentParser
8
9
10
11
12
13
14
15

# Every test in this module asserts V1 defaults, so the whole module is
# skipped (rather than failed) when envs.VLLM_USE_V1 is falsy.
if not envs.VLLM_USE_V1:
    pytest.skip("Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
                allow_module_level=True)


16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def test_prefix_caching_from_cli():
    """Check how CLI flags map onto ``EngineArgs.enable_prefix_caching``.

    Covers three invocations: no flags (expected on),
    ``--no-enable-prefix-caching`` (expected off) and
    ``--enable-prefix-caching`` (expected on).
    """
    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    def prefix_caching_for(argv):
        # Parse `argv` with the shared parser and report the resulting flag.
        namespace = cli_parser.parse_args(argv)
        return EngineArgs.from_cli_args(args=namespace).enable_prefix_caching

    # No flags at all: the default must be enabled.
    assert prefix_caching_for([]), "V1 turns on prefix caching by default."

    # It can be turned off explicitly with the negated flag.
    assert not prefix_caching_for(["--no-enable-prefix-caching"])

    # It can be turned on explicitly with the positive flag.
    assert prefix_caching_for(["--enable-prefix-caching"])


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def test_defaults():
    """Check that a directly constructed ``EngineArgs`` enables prefix caching."""
    default_args = EngineArgs(model="facebook/opt-125m")

    # Only the model is supplied, so this reads the untouched default.
    prefix_caching = default_args.enable_prefix_caching
    assert prefix_caching, "V1 turns on prefix caching by default"


def test_defaults_with_usage_context():
    """Check scheduler defaults produced for each ``UsageContext``.

    For both contexts ``max_num_seqs`` is expected to be 1024; the expected
    ``max_num_batched_tokens`` differs per context (8192 for ``LLM_CLASS``,
    2048 for ``OPENAI_API_SERVER``).
    """
    expected_token_budget = (
        (UsageContext.LLM_CLASS, 8192),
        (UsageContext.OPENAI_API_SERVER, 2048),
    )

    for usage_context, max_batched_tokens in expected_token_budget:
        # A fresh EngineArgs per context, as each case starts from defaults.
        engine_args = EngineArgs(model="facebook/opt-125m")
        vllm_config: VllmConfig = engine_args.create_engine_config(
            usage_context)

        scheduler_config = vllm_config.scheduler_config
        assert scheduler_config.max_num_seqs == 1024
        assert scheduler_config.max_num_batched_tokens == max_batched_tokens


def test_prefix_cache_disabled_with_multimodel():
    """Check that the engine config for the LLaVA checkpoint has prefix caching off.

    NOTE(review): the test name suggests this is because the model is
    multimodal; the reason is not visible here -- confirm in
    ``create_engine_config``.
    """
    llava_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
    cache_config = llava_args.create_engine_config(
        UsageContext.LLM_CLASS).cache_config

    assert not cache_config.enable_prefix_caching