test_sampling_params.py 2.74 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for the SamplingParams class.
"""
5

6
import os
7
8
import pytest

9
from vllm import SamplingParams
10
11
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
12
from utils import models_path_prefix
13

14
MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen1.5-7B")
15
16
17
18
19
20
21


def test_max_tokens_none():
    """Constructing SamplingParams with ``max_tokens=None`` must not raise."""
    # The test passes as long as the constructor accepts these arguments;
    # the resulting object is intentionally discarded.
    params_kwargs = {
        "temperature": 0.01,
        "top_p": 0.1,
        "max_tokens": None,
    }
    SamplingParams(**params_kwargs)


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@pytest.fixture(scope="module")
def model_config():
    """Module-scoped fixture that builds a ``ModelConfig`` for ``MODEL_NAME``.

    The same path is used for both the model and its tokenizer.
    """
    config_kwargs = {
        "task": "auto",
        "tokenizer": MODEL_NAME,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "float16",
        "revision": None,
    }
    # MODEL_NAME stays positional, exactly as the constructor is normally
    # invoked; everything else is passed by keyword.
    return ModelConfig(MODEL_NAME, **config_kwargs)


@pytest.fixture(scope="module")
def default_max_tokens():
    """Module-scoped fixture supplying the default max-token value (4096)."""
    max_tokens = 4096
    return max_tokens


zhuwenwen's avatar
zhuwenwen committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# def test_sampling_params_from_request_with_no_guided_decoding_backend(
#         model_config, default_max_tokens):
#     # guided_decoding_backend is not present at request level
#     request = ChatCompletionRequest.model_validate({
#         'messages': [{
#             'role': 'user',
#             'content': 'Hello'
#         }],
#         'model':
#         MODEL_NAME,
#         'response_format': {
#             'type': 'json_object',
#         },
#     })

#     sampling_params = request.to_sampling_params(
#         default_max_tokens,
#         model_config.logits_processor_pattern,
#     )
#     # we do not expect any backend to be present and the default
#     # guided_decoding_backend at engine level will be used.
#     assert sampling_params.guided_decoding.backend is None


# @pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
#                          [("xgrammar", "xgrammar"),
#                           ("lm-format-enforcer", "lm-format-enforcer"),
#                           ("outlines", "outlines")])
# def test_sampling_params_from_request_with_guided_decoding_backend(
#         request_level_guided_decoding_backend: str, expected: str,
#         model_config, default_max_tokens):

#     request = ChatCompletionRequest.model_validate({
#         'messages': [{
#             'role': 'user',
#             'content': 'Hello'
#         }],
#         'model':
#         MODEL_NAME,
#         'response_format': {
#             'type': 'json_object',
#         },
#         'guided_decoding_backend':
#         request_level_guided_decoding_backend,
#     })

#     sampling_params = request.to_sampling_params(
#         default_max_tokens,
#         model_config.logits_processor_pattern,
#     )
#     # backend correctly identified in resulting sampling_params
#     assert sampling_params.guided_decoding.backend == expected