# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cached_property

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers.mistral import MistralTokenizer

# Module-level logger, named after this module.
logger = init_logger(__name__)


class MistralReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Mistral models.

    Mistral models use `[THINK]`...`[/THINK]` tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.

    A valid reasoning trace should always start with a `[THINK]` token and end with
    a `[/THINK]` token.

    If `[THINK]` token is not generated, then this parser only returns content.

    In the comments below, BOT means the begin-of-think token (`start_token`) and
    EOT means the end-of-think token (`end_token`).
    """

    def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
        """
        Build the parser and resolve the ids of the think start/end tokens.

        Args:
            tokenizer: The Mistral tokenizer; its inner `tokenizer` attribute is
                queried for the control-token ids.
            *args: Forwarded unchanged to `ReasoningParser.__init__`.
            **kwargs: Forwarded unchanged to `ReasoningParser.__init__`.

        Raises:
            ValueError: If `tokenizer` is not a `MistralTokenizer`, or if
                `self.model_tokenizer` is falsy after base initialisation.
            RuntimeError: If either think token id resolves to `None`.
        """
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("The tokenizer must be an instance of MistralTokenizer.")

        # NOTE(review): ReasoningParser.__init__ is invoked directly instead of
        # super().__init__(); presumably this deliberately bypasses the
        # BaseThinkingReasoningParser initialiser because the token ids are
        # resolved below through the Mistral tokenizer - confirm.
        ReasoningParser.__init__(self, tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token)
        self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token)

        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "Mistral reasoning parser could not locate think start/end "
                "tokens in the tokenizer!"
            )

    @cached_property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        # Imported lazily so mistral_common is only needed once the property is
        # first accessed.
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        return SpecialTokens.begin_think

    @cached_property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        # Imported lazily so mistral_common is only needed once the property is
        # first accessed.
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        return SpecialTokens.end_think

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """
        Tell whether the reasoning section has been closed in `input_ids`.

        The ids are scanned from the end. When the last BOT token of the
        sequence is reached, the result is whether an EOT token was seen after
        it. Without any BOT token the result is always `False`.

        Args:
            input_ids: Token ids to inspect.

        Returns:
            `True` if an EOT token follows the last BOT token, else `False`.
        """
        seen_eot = False

        for token_id in reversed(input_ids):
            if token_id == self.start_token_id:
                # Reasoning ends only if a BOT token is found before a EOT token.
                return seen_eot
            if token_id == self.end_token_id:
                seen_eot = True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content token ids, i.e. the ids outside the reasoning span.

        Only the first BOT token is taken into account and the scan stops at
        the first EOT token, whether or not a BOT token preceded it.

        Args:
            input_ids: Token ids produced so far.

        Returns:
            The ids that are not part of the reasoning span (see the numbered
            cases in the body).
        """
        has_bot_token = False
        has_eot_token = False
        bot_token_index = -1
        eot_token_index = -1
        # One for loop instead of multiple lookups
        for i, token_id in enumerate(input_ids):
            # We filter that we have multiple BOT tokens which should not
            # happen for a well prompted trained model
            if token_id == self.start_token_id and not has_bot_token:
                has_bot_token = True
                bot_token_index = i
            elif token_id == self.end_token_id:
                has_eot_token = True
                eot_token_index = i
                break

        # 1. Only BOT has been outputted
        if has_bot_token and not has_eot_token:
            # Should be = [] if model is well prompted and trained.
            return input_ids[:bot_token_index]
        # 2. Neither BOT or EOT have been outputted
        elif not has_bot_token and not has_eot_token:
            return input_ids
        # 3. Both BOT and EOT have been outputted.
        elif has_bot_token and has_eot_token:
            return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :]
        # 4. Only EOT has been outputted => this should not have occurred for a model
        #    well prompted and trained.
        else:
            return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating request; not read by this method.

        Returns:
            A `(reasoning, content)` tuple. `reasoning` is `None` when no BOT
            token is present. `content` is `None` when a BOT token is present
            and no text remains outside the reasoning span.
        """
        if not model_output:
            return (None, "")

        # Check if the start token is present in the model output, remove it
        # if it is present.
        prev_bot_token, bot_token, post_bot_token = model_output.partition(
            self.start_token
        )

        has_bot_token = bool(bot_token)
        # Valid EOT tokens should follow BOT token
        has_valid_eot_token = has_bot_token and self.end_token in post_bot_token

        # 1. If there is BOT token followed by EOT token
        if has_valid_eot_token:
            prev_eot_token, _, post_eot_token = post_bot_token.partition(self.end_token)
            # If model is well prompted and trained prev_bot_token should be ""
            content = prev_bot_token + post_eot_token
            return prev_eot_token, content if content else None

        # 2. Only BOT token
        if has_bot_token:
            # If model is well prompted and trained prev_bot_token should be ""
            return post_bot_token, prev_bot_token if prev_bot_token else None

        # 3. EOT token has been outputted without BOT or neither has been outputted.
        #    Without a BOT token, prev_bot_token holds the whole model output.
        # 3.a EOT token has been outputted without BOT
        # If model is well prompted and trained this branch should not be taken
        # and the parser outputs all tokens as 'content'
        if self.end_token in prev_bot_token:
            prev_eot_token, _, post_eot_token = prev_bot_token.partition(
                self.end_token
            )
            return None, prev_eot_token + post_eot_token

        # 3.b neither BOT or EOT have been outputted
        return None, prev_bot_token