# SPDX-License-Identifier: Apache-2.0

import os
from argparse import ArgumentError

import pytest

from vllm import envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

from ...utils import models_path_prefix

# The tests below assert V1 engine defaults (e.g. prefix caching on by
# default), so the whole module is skipped unless V1 is enabled.
if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


def test_prefix_caching_from_cli():
    """Check the prefix-caching CLI flags and the hash-algorithm option.

    Prefix caching is expected to be on by default under V1.
    ``--no-enable-prefix-caching`` turns it off and ``--enable-prefix-caching``
    turns it back on. ``--prefix-caching-hash-algo`` accepts ``builtin`` (the
    default) and ``sha256`` and rejects unknown values with ``ArgumentError``.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Point at the same local checkpoint that test_defaults_with_usage_context
    # uses, instead of falling back to the EngineArgs default model.
    model_args = [
        "--model",
        os.path.join(models_path_prefix, "facebook/opt-125m"),
    ]

    def _config_from_cli(*extra_args: str) -> VllmConfig:
        """Parse ``model_args`` plus ``extra_args`` and build the config."""
        cli_args = parser.parse_args([*model_args, *extra_args])
        return EngineArgs.from_cli_args(args=cli_args).create_engine_config()

    vllm_config = _config_from_cli()
    assert vllm_config.cache_config.enable_prefix_caching, (
        "V1 turns on prefix caching by default.")

    # Turning it off is possible with a flag.
    vllm_config = _config_from_cli("--no-enable-prefix-caching")
    assert not vllm_config.cache_config.enable_prefix_caching

    # Turn it on with a flag.
    vllm_config = _config_from_cli("--enable-prefix-caching")
    assert vllm_config.cache_config.enable_prefix_caching

    # No hash algorithm was passed above, so the default ("builtin") applies.
    assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"

    # Explicitly selecting each supported hash algorithm is honored.
    for algo in ("sha256", "builtin"):
        vllm_config = _config_from_cli("--prefix-caching-hash-algo", algo)
        assert vllm_config.cache_config.prefix_caching_hash_algo == algo

    # With exit_on_error disabled, the invalid choice surfaces as an
    # ArgumentError (asserted here) instead of exiting the process.
    parser.exit_on_error = False
    with pytest.raises(ArgumentError):
        parser.parse_args(["--prefix-caching-hash-algo", "invalid"])


def test_defaults_with_usage_context():
    """Check scheduler defaults selected per UsageContext and device.

    ``UsageContext.LLM_CLASS`` and ``UsageContext.OPENAI_API_SERVER`` are
    expected to share the same ``max_num_seqs`` default but use different
    ``max_num_batched_tokens`` defaults. H100/H200 devices get larger values.
    """
    # Both configs are built from the same local checkpoint.
    model_path = os.path.join(models_path_prefix, "facebook/opt-125m")

    engine_args = EngineArgs(model=model_path)
    vllm_config: VllmConfig = engine_args.create_engine_config(
        UsageContext.LLM_CLASS)

    from vllm.platforms import current_platform
    device_name = current_platform.get_device_name().lower()
    if "h100" in device_name or "h200" in device_name:
        # For H100 and H200, we use larger default values.
        default_llm_tokens = 16384
        default_server_tokens = 8192
        default_max_num_seqs = 1024
    else:
        default_llm_tokens = 8192
        default_server_tokens = 2048
        default_max_num_seqs = 256

    assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
    assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens  # noqa: E501

    # Build a fresh EngineArgs for the API-server usage context.
    engine_args = EngineArgs(model=model_path)
    vllm_config = engine_args.create_engine_config(
        UsageContext.OPENAI_API_SERVER)
    assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
    assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens  # noqa: E501