# test_serving_chat.py
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

# Model identifier used both as the tokenizer name and as the served model name.
MODEL_NAME = "openai-community/gpt2"
# Placeholder template; the tests only check that it is stored and passed along.
CHAT_TEMPLATE = "Dummy chat template for testing {}"


@dataclass
class MockModelConfig:
    """Minimal stand-in for the model config handed to ``OpenAIServingChat``.

    Every attribute is an annotated dataclass field with a default, so
    ``MockModelConfig()`` yields the standard test configuration while
    individual values can still be overridden by keyword.
    """

    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"
    # Small limit; the max_tokens test in this file asserts values derived
    # from it (presumably 100 minus the prompt length -- confirm).
    max_model_len: int = 100
    tokenizer_revision: Optional[str] = None
    embedding_mode: bool = False


@dataclass
class MockEngine:
    """Bare-bones engine stub that exposes only ``get_model_config``."""

    async def get_model_config(self):
        """Return a freshly constructed ``MockModelConfig``."""
        return MockModelConfig()


async def _async_serving_chat_init():
    """Build an ``OpenAIServingChat`` on top of the mock engine.

    The model config is obtained by awaiting the engine's
    ``get_model_config`` and is passed to the constructor together with
    the module-level ``MODEL_NAME`` and ``CHAT_TEMPLATE``.

    Returns:
        The constructed ``OpenAIServingChat`` instance.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     served_model_names=[MODEL_NAME],
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """The chat template given at construction is exposed unchanged."""
    # The builder is a coroutine, so drive it to completion on a fresh loop.
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` value that reaches the engine's ``generate``.

    The request is submitted twice: first without ``max_tokens`` (expecting
    93 to be forwarded), then with ``max_tokens`` set to 10 (expecting 10).
    """
    engine = MagicMock(spec=AsyncLLMEngine)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    chat_kwargs = dict(
        served_model_names=[MODEL_NAME],
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
    )
    handler = OpenAIServingChat(engine, MockModelConfig(), **chat_kwargs)

    user_turn = {"role": "user", "content": "what is 1+1?"}
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_turn],
        guided_decoding_backend="outlines",
    )

    def forwarded_max_tokens():
        # Whatever the call raises is deliberately ignored: only the
        # arguments recorded on the mocked ``generate`` are inspected.
        with suppress(Exception):
            asyncio.run(handler.create_chat_completion(request))
        # Second positional argument of the most recent ``generate`` call
        # (presumably the sampling parameters -- confirm against the engine).
        return engine.generate.call_args.args[1].max_tokens

    # No explicit limit on the request.
    # NOTE(review): 93 is presumably max_model_len (100) minus the prompt's
    # token count -- confirm.
    assert forwarded_max_tokens() == 93

    # An explicit limit on the request is forwarded as-is.
    request.max_tokens = 10
    assert forwarded_max_tokens() == 10