"tests/vscode:/vscode.git/clone" did not exist on "9556af87d5d5a38128db0d09eeb7f2fe16f16589"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Mistral renderer (``vllm.renderers.mistral``)."""

import asyncio
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock

import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer

# Checkpoint name used as the default model/tokenizer of the mock model config.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"


@dataclass()
class MockHFConfig:
    """Stand-in for a Hugging Face model config; only ``model_type`` is set."""

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight mock of a model config, handed to ``MistralRenderer``
    through ``MockVllmConfig``.

    Only annotated names are dataclass fields (constructor arguments). The
    unannotated ones (``runner_type``, ``tokenizer_revision``,
    ``tokenizer_mode``, ``hf_config``) are plain class attributes and cannot
    be overridden via the constructor.
    """

    runner_type = "generate"
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    max_model_len: int = 100
    tokenizer_revision = None
    tokenizer_mode = "mistral"
    # Class attribute: this single MockHFConfig instance is shared by every
    # MockModelConfig, so tests should not mutate it.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False
    renderer_num_workers: int = 1


@dataclass()
class MockParallelConfig:
    """Mock parallel config carrying only the API process rank."""

    _api_process_rank: int = 0


@dataclass
class MockVllmConfig:
    """Mock top-level config bundling the model and parallel sub-configs that
    the tests pass to ``MistralRenderer``."""

    model_config: MockModelConfig
    parallel_config: MockParallelConfig


@pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop():
    """``render_messages_async`` must keep the event loop responsive while the
    synchronous tokenizer ``apply_chat_template`` runs.

    The tokenizer is mocked to ``time.sleep(2)``. Meanwhile, the test
    repeatedly yields to the loop and counts every yield that takes >= 0.5 s.
    """
    expected_tokens = [1, 2, 3]

    # Mock the blocking version to sleep
    def mocked_apply_chat_template(*_args, **_kwargs):
        time.sleep(2)
        return expected_tokens

    mock_model_config = MockModelConfig(skip_tokenizer_init=True)

    mock_tokenizer = Mock(spec=MistralTokenizer)
    mock_tokenizer.apply_chat_template = mocked_apply_chat_template
    mock_renderer = MistralRenderer(
        MockVllmConfig(mock_model_config, parallel_config=MockParallelConfig()),
        tokenizer=mock_tokenizer,
    )

    # Schedule the rendering immediately. If render_messages_async returns a
    # bare coroutine, it would not start until awaited below (after the
    # responsiveness checks), so a loop-blocking implementation would go
    # undetected. ensure_future wraps a coroutine in a Task and passes an
    # existing Future through unchanged.
    task = asyncio.ensure_future(
        mock_renderer.render_messages_async([], ChatParams())
    )

    # Ensure the event loop is not blocked
    blocked_count = 0
    for _i in range(20):  # Check over ~2 seconds
        start = time.perf_counter()
        await asyncio.sleep(0)
        elapsed = time.perf_counter() - start

        # an overly generous elapsed time for slow machines
        if elapsed >= 0.5:
            blocked_count += 1

        await asyncio.sleep(0.1)

    # Ensure task completes
    _, prompt = await task
    assert prompt["prompt_token_ids"] == expected_tokens, (
        "Mocked blocking tokenizer was not called"
    )
    assert blocked_count == 0, "Event loop blocked during tokenization"


def test_apply_mistral_chat_template_thinking_chunk():
    """Render a conversation that contains closed ``thinking`` chunks with the
    Magistral tokenizer. Decode it with special tokens kept and compare the
    result with the expected prompt text."""
    system_turn = {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."},
            {
                "type": "thinking",
                "closed": True,
                "thinking": "Only return the answer when you are confident.",
            },
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me think about it."},
            {"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
            {"type": "text", "text": "The answer is 4."},
        ],
    }
    conversation = [
        system_turn,
        {"role": "user", "content": "What is 2+2?"},
        assistant_turn,
        {"role": "user", "content": "Thanks, what is 3+3?"},
    ]

    tokenizer = MistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")
    rendered_ids = safe_apply_chat_template(
        tokenizer, conversation, chat_template=None, tools=None
    )
    decoded = tokenizer.mistral.decode(
        rendered_ids, special_token_policy=SpecialTokenPolicy.KEEP
    )

    # Thinking chunks render as [THINK]...[/THINK] inline within their turn.
    assert decoded == (
        r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
        r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
        r"[INST]What is 2+2?[/INST]"
        r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
        r"[INST]Thanks, what is 3+3?[/INST]"
    )