# test_serving_chat.py
1
import asyncio
2
from contextlib import suppress
3
from dataclasses import dataclass
4
from unittest.mock import MagicMock
5

6
from vllm.config import MultiModalConfig
7
from vllm.engine.multiprocessing.client import MQLLMEngineClient
8
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
9
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
10
from vllm.entrypoints.openai.serving_engine import BaseModelPath
11
from vllm.transformers_utils.tokenizer import get_tokenizer
12
13
14

# Model identifier reused throughout this file: as the tokenizer name, as the
# served model name/path, and as the `model` field of the test request.
MODEL_NAME = "openai-community/gpt2"

# Placeholder template string handed to OpenAIServingChat as `chat_template`.
CHAT_TEMPLATE = "Dummy chat template for testing {}"

# The single base model registered with the serving layer under test.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
16
17


18
19
20
21
22
@dataclass
class MockHFConfig:
    """Minimal placeholder used as ``MockModelConfig.hf_config``.

    Carries a single field, ``model_type``, which defaults to ``"any"``.
    """

    model_type: str = "any"


23
24
@dataclass
class MockModelConfig:
    """Stand-in model config passed to ``OpenAIServingChat`` in these tests.

    None of the attributes below carries a type annotation, so ``@dataclass``
    does not turn them into fields: they are ordinary class attributes shared
    by every instance, and the generated ``__init__`` takes no arguments.
    """

    task = "generate"

    # Tokenizer-related settings.
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    # NOTE(review): presumably the budget behind the default max_tokens of 93
    # asserted in the max_tokens test (100 minus prompt length) - confirm.
    max_model_len = 100
    tokenizer_revision = None

    # Evaluated once at class-definition time and shared across instances.
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()

    logits_processor_pattern = None
34
35
36
37
38
39


@dataclass
class MockEngine:
    """Bare-bones engine double exposing only ``get_model_config``."""

    async def get_model_config(self):
        """Return a new ``MockModelConfig`` instance."""
        return MockModelConfig()
41
42
43


async def _async_serving_chat_init():
    """Construct an ``OpenAIServingChat`` from inside a coroutine.

    The model config is obtained by awaiting ``MockEngine.get_model_config``
    and is then passed, together with the mock engine, to the constructor.

    Returns:
        The constructed ``OpenAIServingChat`` instance.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """The chat server built inside an event loop keeps its chat template."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE
62
63
64


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` value forwarded to ``engine.generate``.

    Two requests are issued against a mocked engine client:

    1. without ``max_tokens`` set on the request, the sampling params passed
       to ``generate`` are expected to carry ``max_tokens == 93``;
    2. with ``max_tokens = 10`` set on the request, that value is expected
       to be forwarded unchanged.
    """
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # Errors raised while handling the request are deliberately ignored: only
    # the arguments recorded by the mocked generate() matter for this test.
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # Fail with a readable message instead of an AttributeError on
    # `call_args` (which is None when generate() was never invoked).
    assert mock_engine.generate.called, (
        "create_chat_completion did not reach engine.generate")
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Drop the recorded call so the second check cannot read stale call_args.
    mock_engine.generate.reset_mock()

    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.called, (
        "create_chat_completion did not reach engine.generate")
    assert mock_engine.generate.call_args.args[1].max_tokens == 10