"vscode:/vscode.git/clone" did not exist on "bf4a901af91e431a4cdb51ed4557b31bf89c0e5d"
test_serving_chat.py 5.35 KB
Newer Older
1
import asyncio
2
from contextlib import suppress
3
from dataclasses import dataclass
4
from typing import Optional
5
from unittest.mock import MagicMock
6

7
from vllm.config import MultiModalConfig
8
from vllm.engine.multiprocessing.client import MQLLMEngineClient
9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
10
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
11
from vllm.entrypoints.openai.serving_engine import BaseModelPath
12
from vllm.transformers_utils.tokenizer import get_tokenizer
13
14
15

# Model identifier used throughout this module: as the served model name, as
# the tokenizer source, and as the ``model`` field of every request.
MODEL_NAME = "openai-community/gpt2"
# Placeholder chat template handed to every OpenAIServingChat built here.
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# A single served model whose public name and path are both MODEL_NAME.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
17
18


19
20
21
22
23
@dataclass
class MockHFConfig:
    """Minimal stand-in for the ``hf_config`` attribute of a model config.

    It carries a single field, ``model_type``, which defaults to the
    placeholder value ``"any"``.
    """

    model_type: str = "any"


24
25
@dataclass
class MockModelConfig:
    """Lightweight substitute for the model config given to OpenAIServingChat.

    Only ``diff_sampling_param`` carries a type annotation, so it is the sole
    dataclass field and the only argument accepted by ``__init__``. Every
    other name below is a plain class attribute shared by all instances.
    """

    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    # Per-instance sampling overrides; exposed via get_diff_sampling_param().
    diff_sampling_param: Optional[dict] = None

    def get_diff_sampling_param(self) -> dict:
        """Return the sampling overrides, or ``{}`` when unset or empty."""
        return self.diff_sampling_param or {}
39
40
41
42
43
44


@dataclass
class MockEngine:
    """Bare-bones engine stand-in that exposes only ``get_model_config``."""

    async def get_model_config(self):
        """Return a newly constructed ``MockModelConfig``."""
        return MockModelConfig()
46
47
48


async def _async_serving_chat_init():
    """Build an ``OpenAIServingChat`` around ``MockEngine`` from a coroutine.

    The model config is fetched through the engine's async accessor, then the
    serving object is constructed with the module-level test constants.

    Returns:
        The constructed ``OpenAIServingChat`` instance.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """OpenAIServingChat built in an event loop keeps its chat template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
67
68
69


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` forwarded to the engine's ``generate`` call.

    With no ``max_tokens`` on the request the forwarded value must be 93;
    after setting ``req.max_tokens = 10`` it must be 10.
    """
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # Errors are deliberately swallowed -- presumably because the MagicMock
    # engine cannot yield real outputs. Only the arguments recorded on
    # generate() are inspected afterwards.
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # args[1] is the second positional argument of generate().
    # NOTE(review): 93 is presumably max_model_len (100) minus the number of
    # prompt tokens -- confirm against the tokenizer output.
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # An explicit user value must be forwarded unchanged.
    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157


def test_serving_chat_could_load_correct_generation_config():
    """Check that model-config sampling overrides reach ``generate``.

    The config supplies ``temperature`` and ``repetition_penalty`` overrides.
    They must be used when the request is silent, and a temperature set on
    the request (including ``0.0``) must take precedence.
    """
    mock_model_config = MockModelConfig(diff_sampling_param={
        "temperature": 0.5,
        "repetition_penalty": 1.05
    })

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.errored = False
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    # Serving layer under test, wired to the mocked engine and config.
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    user_message = {"role": "user", "content": "what is 1+1?"}
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_message],
        guided_decoding_backend="outlines",
    )

    def submit_and_capture():
        # Run the request, ignoring any error raised along the way, and hand
        # back the second positional argument recorded on generate().
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(req))
        return mock_engine.generate.call_args.args[1]

    # Request leaves temperature unset: the config overrides apply.
    captured = submit_and_capture()
    assert captured.temperature == 0.5
    assert captured.repetition_penalty == 1.05

    # A user-provided temperature wins, including the 0.0 edge case, while
    # the repetition penalty still comes from the config.
    for user_temperature in (0.1, 0.0):
        req.temperature = user_temperature
        captured = submit_and_capture()
        assert captured.temperature == user_temperature
        assert captured.repetition_penalty == 1.05