"vllm/entrypoints/openai/models/serving.py" did not exist on "cf069aa8aa38a9003c254f8434a29ec6a3070b08"
mistral_reasoning_parser.py 6.17 KB
Newer Older
Julien Denize's avatar
Julien Denize committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from collections.abc import Sequence
5
6
from functools import cached_property

7
from vllm.entrypoints.openai.chat_completion.protocol import (
8
    ChatCompletionRequest,
9
)
10
from vllm.entrypoints.openai.responses.protocol import (
11
12
    ResponsesRequest,
)
Julien Denize's avatar
Julien Denize committed
13
from vllm.logger import init_logger
14
from vllm.reasoning import ReasoningParser
15
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
16
from vllm.tokenizers.mistral import MistralTokenizer
Julien Denize's avatar
Julien Denize committed
17
18
19
20

logger = init_logger(__name__)


21
class MistralReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Mistral models.

    Mistral models use `[THINK]`...`[/THINK]` tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.

    A valid reasoning trace should always start with a `[THINK]` token and end
    with a `[/THINK]` token.

    If the `[THINK]` token is not generated, this parser only returns content.
    """

    def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
        """
        Build the parser and resolve the ids of the think start/end tokens.

        Args:
            tokenizer: The tokenizer of the served model; it must be a
                `MistralTokenizer`.
            *args: Forwarded unchanged to `ReasoningParser.__init__`.
            **kwargs: Forwarded unchanged to `ReasoningParser.__init__`.

        Raises:
            ValueError: If `tokenizer` is not a `MistralTokenizer`, or if
                `self.model_tokenizer` is falsy after the base initialisation.
            RuntimeError: If either control token id resolves to `None`.
        """
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("The tokenizer must be an instance of MistralTokenizer.")

        # NOTE(review): `ReasoningParser.__init__` is called explicitly instead
        # of `super().__init__`, which skips the initialiser of
        # `BaseThinkingReasoningParser`. Presumably this is because the token
        # ids are resolved through the Mistral tokenizer below -- confirm
        # before changing.
        ReasoningParser.__init__(self, tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token)
        self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token)

        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "Mistral reasoning parser could not locate think start/end "
                "tokens in the tokenizer!"
            )

    @cached_property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        # Imported lazily, so `mistral_common` is only needed on first access.
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        return SpecialTokens.begin_think

    @cached_property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        # Imported lazily, so `mistral_common` is only needed on first access.
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        return SpecialTokens.end_think

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """
        Tell whether the most recent reasoning block has been closed.

        The ids are scanned from the end until the last BOT (think start)
        token is reached.

        Args:
            input_ids: The token ids to inspect.

        Returns:
            `True` if an EOT (think end) token appears after the last BOT
            token. `False` if that BOT token has no EOT token after it, or if
            there is no BOT token at all (even when an EOT token is present).
        """
        has_eot_token = False

        # `reversed` walks the sequence backwards without copying it.
        for token_id in reversed(input_ids):
            if token_id == self.start_token_id:
                # Reasoning ends only if a BOT token is found before a EOT token.
                return has_eot_token
            if token_id == self.end_token_id:
                has_eot_token = True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content token ids, i.e. drop the reasoning span.

        Only the first BOT token and the first EOT token are considered; the
        scan stops at the first EOT token.

        Args:
            input_ids: The token ids generated by the model.

        Returns:
            The ids that are not part of the reasoning span. When neither
            token is present, `input_ids` itself is returned (not a copy).
        """
        bot_token_index = -1
        eot_token_index = -1
        # One for loop instead of multiple lookups
        for i, token_id in enumerate(input_ids):
            # Only the first BOT token is recorded: multiple BOT tokens should
            # not happen for a well prompted trained model.
            if token_id == self.start_token_id and bot_token_index == -1:
                bot_token_index = i
            elif token_id == self.end_token_id:
                eot_token_index = i
                break

        has_bot_token = bot_token_index != -1
        has_eot_token = eot_token_index != -1

        # 1. Only BOT has been outputted
        if has_bot_token and not has_eot_token:
            # Should be = [] if model is well prompted and trained.
            return input_ids[:bot_token_index]
        # 2. Neither BOT or EOT have been outputted
        if not has_bot_token and not has_eot_token:
            return input_ids
        # 3. Both BOT and EOT have been outputted.
        if has_bot_token:
            return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :]
        # 4. Only EOT has been outputted => this should not have occurred for a
        #    model well prompted and trained. Only the EOT token is dropped.
        return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        Args:
            model_output: The complete text generated by the model.
            request: The originating request; it is not used by this parser.

        Returns:
            A `(reasoning, content)` tuple:

            - empty output: `(None, "")`;
            - BOT followed by EOT: the text between them, and the text around
              them (or `None` if that text is empty);
            - BOT only: everything after BOT, and the text before BOT (or
              `None` if that text is empty);
            - no BOT: `None`, and the output with the first EOT token removed
              if there is one.
        """
        if not model_output:
            return (None, "")

        # Split around the first BOT token; `bot_token` is "" when it is absent.
        prev_bot_token, bot_token, post_bot_token = model_output.partition(
            self.start_token
        )

        if bot_token:
            # A valid EOT token must follow the BOT token; `eot_token` is ""
            # when no EOT token follows.
            prev_eot_token, eot_token, post_eot_token = post_bot_token.partition(
                self.end_token
            )
            # 1. If there is BOT token followed by EOT token
            if eot_token:
                # If model is well prompted and trained prev_bot_token should be ""
                content = prev_bot_token + post_eot_token
                return prev_eot_token, content if content else None
            # 2. Only BOT token
            # If model is well prompted and trained prev_bot_token should be ""
            return post_bot_token, prev_bot_token if prev_bot_token else None

        # 3. EOT token has been outputted without BOT or neither has been
        #    outputted. Without a BOT token `prev_bot_token` is the whole output.
        if self.end_token in prev_bot_token:
            # 3.a EOT token has been outputted without BOT: drop the first EOT
            # token and return everything else as content. This should not
            # happen if the model is well prompted and trained.
            prev_eot_token, _, post_eot_token = prev_bot_token.partition(
                self.end_token
            )
            return None, prev_eot_token + post_eot_token
        # 3.b neither BOT or EOT have been outputted
        return None, prev_bot_token