test_serving_chat.py 5.56 KB
Newer Older
1
import asyncio
2
from contextlib import suppress
3
from dataclasses import dataclass
4
from typing import Optional
5
from unittest.mock import MagicMock
6

7
from vllm.config import MultiModalConfig
8
from vllm.engine.multiprocessing.client import MQLLMEngineClient
9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
10
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
11
12
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
13
from vllm.transformers_utils.tokenizer import get_tokenizer
14
15
16

# Model identifier reused throughout this module as the served model name,
# the model path and the tokenizer name.
MODEL_NAME = "openai-community/gpt2"
# Placeholder template text; the init test only checks that the serving
# object stores exactly this string.
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# Single base model entry handed to OpenAIServingModels in every test.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
18
19


20
21
22
23
24
@dataclass
class MockHFConfig:
    """Minimal stand-in assigned to ``MockModelConfig.hf_config``."""

    # Placeholder value; no test visible in this module reads it directly.
    model_type: str = "any"


25
26
@dataclass
class MockModelConfig:
    """Lightweight substitute for the model config consumed by the servers.

    Only ``diff_sampling_param`` and ``allowed_local_media_path`` carry type
    annotations, so they are the only dataclass fields (and therefore the
    only ``__init__`` parameters). Every other name below is a plain class
    attribute shared by all instances.
    """

    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    # NOTE(review): presumably the context length that bounds the default
    # ``max_tokens`` asserted in the max-tokens test - confirm.
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    # Per-model sampling overrides; ``None`` means "no overrides".
    diff_sampling_param: Optional[dict] = None
    allowed_local_media_path: str = ""
    encoder_config = None

    def get_diff_sampling_param(self):
        """Return the sampling overrides, or an empty dict when unset/empty."""
        return self.diff_sampling_param or {}
42
43
44
45
46
47


@dataclass
class MockEngine:
    """Bare-bones engine stub that only exposes ``get_model_config``."""

    async def get_model_config(self):
        """Return a freshly constructed ``MockModelConfig``."""
        return MockModelConfig()
49
50
51


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` while an event loop is running.

    Returns:
        The ``OpenAIServingChat`` built on top of ``MockEngine`` and its
        ``MockModelConfig``.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    # Keyword arguments mirror how the other tests in this module build
    # OpenAIServingModels.
    models = OpenAIServingModels(engine_client=engine,
                                 model_config=model_config,
                                 base_model_paths=BASE_MODEL_PATHS)
    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """The chat server built inside an event loop keeps the given template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
69
70
71


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` passed to the engine, default and explicit."""
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=MockModelConfig())
    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # Exceptions are swallowed on purpose: the test only inspects the
    # arguments recorded on ``mock_engine.generate``.
    # NOTE(review): presumably the mocked engine cannot yield a real
    # response, so the call fails after ``generate`` is invoked - confirm.
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # The request sets no ``max_tokens``.
    # NOTE(review): 93 is presumably ``max_model_len`` (100) minus the
    # tokenized prompt length - confirm.
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # An explicit ``max_tokens`` on the request must be forwarded unchanged.
    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119


def test_serving_chat_could_load_correct_generation_config():
    """Check how model-level sampling overrides combine with request values.

    The model config supplies ``temperature`` and ``repetition_penalty``
    through ``diff_sampling_param``; the assertions verify which values reach
    the engine when the request leaves ``temperature`` unset, sets it to a
    non-zero value, and sets it to ``0.0``.
    """
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05
    }

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # Exceptions are swallowed on purpose: only the arguments recorded on
    # ``mock_engine.generate`` are inspected after each call.
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # Nothing set on the request: both model-level overrides apply.
    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.5
    assert sampling_params.repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.1
    assert sampling_params.repetition_penalty == 1.05

    # Test when temperature == 0.0: an explicit zero must be forwarded as
    # given and must not be replaced by the model-level 0.5.
    req.temperature = 0.0

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    sampling_params = mock_engine.generate.call_args.args[1]
    assert sampling_params.temperature == 0.0
    assert sampling_params.repetition_penalty == 1.05