test_mistral_tokenizer.py 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# SPDX-License-Identifier: Apache-2.0

import pytest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool

from vllm.transformers_utils.tokenizers.mistral import (
    make_mistral_chat_completion_request)


# yapf: enable
# Shared fixture values: both cases use the same user prompt and the same
# tool, and differ only in how the tool's "parameters" field is given.
_PROMPT = "What is the current local date and time?"
_TOOL_NAME = "get_current_time"
_TOOL_DESCRIPTION = "Fetch the current local date and time."


def _expected_mistral_request() -> ChatCompletionRequest:
    """Build the ChatCompletionRequest every case below should map to.

    The expected tool carries ``parameters={}``. The cases therefore check
    that an omitted or ``None`` "parameters" field in the OpenAI-style tool
    ends up as an empty dict.
    """
    return ChatCompletionRequest(
        messages=[UserMessage(content=_PROMPT)],
        tools=[
            Tool(
                type="function",
                function=Function(
                    name=_TOOL_NAME,
                    description=_TOOL_DESCRIPTION,
                    parameters={},
                ),
            )
        ],
    )


@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [
        # The tool's "function" dict has no "parameters" key at all.
        (
            {
                "messages": [{
                    "role": "user",
                    "content": _PROMPT,
                }],
                "tools": [{
                    "type": "function",
                    "function": {
                        "description": _TOOL_DESCRIPTION,
                        "name": _TOOL_NAME,
                    },
                }],
            },
            _expected_mistral_request(),
        ),
        # The tool's "function" dict sets "parameters" explicitly to None.
        (
            {
                "messages": [{
                    "role": "user",
                    "content": _PROMPT,
                }],
                "tools": [{
                    "type": "function",
                    "function": {
                        "description": _TOOL_DESCRIPTION,
                        "name": _TOOL_NAME,
                        "parameters": None,
                    },
                }],
            },
            _expected_mistral_request(),
        ),
    ],
    ids=["parameters_missing", "parameters_none"],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """Check the OpenAI-style -> mistral_common request conversion.

    ``make_mistral_chat_completion_request`` is called with the "messages"
    and "tools" of an OpenAI-style request dict. The result must compare
    equal to the expected ``ChatCompletionRequest``.
    """
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    assert actual == expected_mistral_request