# test_serving_chat.py
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.transformers_utils.tokenizer import get_tokenizer

# Model identifier used both as the served name and as the tokenizer source.
MODEL_NAME = "openai-community/gpt2"
# Chat template handed to OpenAIServingChat; the init test checks it is kept.
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# A single served model whose public name and model path are both MODEL_NAME.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face model config; only ``model_type`` is set."""

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Minimal stand-in for vLLM's model config used by these tests.

    Every attribute below is an unannotated class attribute, so ``@dataclass``
    generates no fields: ``MockModelConfig()`` takes no arguments, and all
    instances share the same ``MultiModalConfig`` and ``MockHFConfig`` objects.

    NOTE(review): assumed to cover every attribute ``OpenAIServingChat`` reads
    during these tests; extend it if the server starts reading more.
    """

    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    chat_template_text_format = "string"
    # Context length; the max_tokens test's expected default (93) presumably
    # derives from this value.
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()


@dataclass
class MockEngine:
    """Fake engine client that only knows how to report a model config."""

    async def get_model_config(self):
        """Return a new ``MockModelConfig`` instance."""
        return MockModelConfig()


async def _async_serving_chat_init():
    """Coroutine that builds an ``OpenAIServingChat`` on a ``MockEngine``.

    Returns:
        The constructed server, wired to a ``MockEngine`` and the model
        config that engine reports, using ``CHAT_TEMPLATE``.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """The chat template passed at construction is stored unchanged."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` forwarded to the engine's ``generate`` call.

    The engine is a ``MagicMock``, so ``create_chat_completion`` may raise
    after it has called ``generate``; ``suppress(Exception)`` tolerates that,
    and the test inspects only the arguments ``generate`` received.
    """
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    # NOTE(review): presumably makes the server treat the engine as healthy.
    mock_engine.errored = False

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    def _max_tokens_sent_to_engine():
        """Run ``req`` once and return ``max_tokens`` as seen by generate()."""
        # Clear earlier calls so a run that never reaches generate() cannot be
        # judged against stale call_args left over from a previous run.
        mock_engine.generate.reset_mock()
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(req))
        call = mock_engine.generate.call_args
        assert call is not None, (
            "create_chat_completion never called engine.generate()")
        # Second positional argument presumably holds the sampling params.
        return call.args[1].max_tokens

    # No max_tokens in the request: 93 == max_model_len (100) minus what is
    # presumably a 7-token rendered prompt.
    assert _max_tokens_sent_to_engine() == 93

    req.max_tokens = 10
    assert _max_tokens_sent_to_engine() == 10