test_engine_args.py 3.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from argparse import ArgumentError

6
7
8
9
10
import pytest

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
11
from vllm.utils.argparse_utils import FlexibleArgumentParser
12
from vllm.utils.hashing import _xxhash
13
14


15
16
17
def test_prefix_caching_from_cli():
    """Exercise the prefix-caching CLI flags and hash-algorithm option.

    Covers the default (prefix caching enabled), the explicit
    ``--no-enable-prefix-caching`` / ``--enable-prefix-caching`` flags, the
    default and explicitly selected ``--prefix-caching-hash-algo`` values,
    and rejection of an unknown algorithm name.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.enable_prefix_caching, (
        "V1 turns on prefix caching by default."
    )

    # It can be turned off with the flag.
    args = parser.parse_args(["--no-enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert not vllm_config.cache_config.enable_prefix_caching

    # Turn it on with flag.
    args = parser.parse_args(["--enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.enable_prefix_caching

    # default hash algorithm is "sha256"
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"

    # set hash algorithm to sha256_cbor
    args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor"

    # set hash algorithm to sha256
    args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"

    # an invalid hash algorithm raises an error
    # exit_on_error is disabled so the parser is expected to raise
    # ArgumentError here rather than exit the process.
    parser.exit_on_error = False
    with pytest.raises(ArgumentError):
        parser.parse_args(["--prefix-caching-hash-algo", "invalid"])

51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Check that both xxhash-based hash algorithms are accepted from the CLI.

    Each algorithm name is passed via ``--prefix-caching-hash-algo`` and must
    come back unchanged on the resulting cache config.
    """
    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # "xxhash" is the pickle-based variant, "xxhash_cbor" the CBOR-based one.
    for algo_name in ("xxhash", "xxhash_cbor"):
        parsed = cli_parser.parse_args(["--prefix-caching-hash-algo", algo_name])
        engine_config = EngineArgs.from_cli_args(args=parsed).create_engine_config()
        assert engine_config.cache_config.prefix_caching_hash_algo == algo_name


67
68
def test_defaults_with_usage_context():
    """Check scheduler defaults for the LLM-class and API-server contexts.

    Builds an engine config for ``facebook/opt-125m`` under
    ``UsageContext.LLM_CLASS`` and ``UsageContext.OPENAI_API_SERVER`` and
    asserts ``max_num_seqs`` and ``max_num_batched_tokens`` against expected
    values that depend on the current device's total memory and name.
    """
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)

    from vllm.platforms import current_platform
    from vllm.utils.mem_constants import GiB_bytes

    device_memory = current_platform.get_device_total_memory()
    device_name = current_platform.get_device_name().lower()
    if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
        # For GPUs like H100, H200, and MI300x with >= 70GB memory
        default_llm_tokens = 16384
        default_server_tokens = 8192
        default_max_num_seqs = 1024
    else:
        default_llm_tokens = 8192
        default_server_tokens = 2048
        default_max_num_seqs = 256

    # LLM-class context: expects the larger batched-token budget.
    scheduler_config = vllm_config.scheduler_config
    assert scheduler_config.max_num_seqs == default_max_num_seqs
    assert scheduler_config.max_num_batched_tokens == default_llm_tokens

    # API-server context: same max_num_seqs, smaller batched-token budget.
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER)
    scheduler_config = vllm_config.scheduler_config
    assert scheduler_config.max_num_seqs == default_max_num_seqs
    assert scheduler_config.max_num_batched_tokens == default_server_tokens