"""Tests for ``OpenAIServingChat`` built on mocked engine and model configs."""
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.transformers_utils.tokenizer import get_tokenizer

# Model name used both as the served model name and to load the tokenizer.
MODEL_NAME = "openai-community/gpt2"
# Fixed template passed to OpenAIServingChat; the init test checks that it
# is stored unchanged.
CHAT_TEMPLATE = "Dummy chat template for testing {}"
# The single served model, registered under MODEL_NAME.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


@dataclass()
class MockHFConfig:
    """Stand-in for a Hugging Face model config carrying only ``model_type``."""

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Minimal stand-in for vLLM's model config used by these tests.

    None of the attributes below are annotated, so ``@dataclass`` generates
    no fields for them: they are plain class attributes shared by every
    instance. That includes the single ``MultiModalConfig`` and
    ``MockHFConfig`` objects created when the class is defined.
    """

    task = "generate"
    # Tokenizer settings; the tokenizer is loaded from the same model name.
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    # Deliberately small context window. The max_tokens test's expected
    # default (93) appears to be derived from it -- TODO confirm.
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()


@dataclass
class MockEngine:
    """Minimal engine stub exposing only ``get_model_config``."""

    async def get_model_config(self):
        """Return a new ``MockModelConfig``.

        Declared async because callers ``await`` it.
        """
        return MockModelConfig()


async def _async_serving_chat_init():
    """Build an ``OpenAIServingChat`` around ``MockEngine`` and return it.

    This is a coroutine because ``MockEngine.get_model_config`` is async.
    """
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    return serving_chat


def test_async_serving_chat_init():
    """The chat template given to the constructor is stored unchanged."""
    serving_chat = asyncio.run(_async_serving_chat_init())
    assert serving_chat.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
    """Check the ``max_tokens`` value that reaches ``engine.generate``.

    Without ``max_tokens`` on the request, the engine receives 93; with an
    explicit ``max_tokens`` the request value is passed through. The engine
    is a ``MagicMock``, so the test only inspects the second positional
    argument of ``generate`` (presumably the sampling params).
    """
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # With a mocked engine, create_chat_completion presumably fails after
    # calling generate (the mock returns MagicMock objects, not real
    # results). Such errors are deliberately ignored; only the arguments
    # passed to generate are checked.
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # Fail with a clear message, rather than an AttributeError on None, if
    # an earlier error prevented generate from being called at all.
    assert mock_engine.generate.call_args is not None, (
        "engine.generate was never called")
    # No max_tokens on the request: 93 presumably equals max_model_len (100)
    # minus the token count of the rendered prompt -- TODO confirm.
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Clear the first call so the checks below cannot see stale arguments.
    mock_engine.generate.reset_mock()
    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args is not None, (
        "engine.generate was never called")
    # An explicit max_tokens on the request is passed through unchanged.
    assert mock_engine.generate.call_args.args[1].max_tokens == 10