utils.py 5.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
6
from vllm.reasoning import ReasoningParser
7
from vllm.tokenizers.mistral import MistralTokenizer
8
9
10
11


class StreamingReasoningReconstructor:
    """Rebuild full reasoning/content text from streamed ``DeltaMessage``s.

    Feed every delta produced by a parser's streaming extraction to
    :meth:`append_delta`. Afterwards ``reasoning`` holds the concatenated
    reasoning text and ``other_content`` the concatenated content; each stays
    ``None`` until a delta carrying that field has been seen.
    """

    def __init__(self):
        # Accumulated reasoning text (None until the first reasoning delta).
        self.reasoning = None
        # Accumulated non-reasoning content (None until the first content delta).
        self.other_content = None

    def append_delta(self, delta: DeltaMessage):
        """Append one streamed delta to the accumulated text.

        Args:
            delta: A message with ``content`` set, ``reasoning`` set, or
                neither (in which case nothing is accumulated).

        Raises:
            AssertionError: If both ``content`` and ``reasoning`` are set, or
                if ``reasoning`` and ``reasoning_content`` disagree.
        """
        # content and the reasoning content should not be present
        # at the same time
        assert delta.content is None or delta.reasoning is None, (
            "Both content and reasoning content are present in the delta message"
        )
        assert delta.reasoning == delta.reasoning_content, (
            "reasoning_content should be present for backwards compatibility"
        )
        if delta.content is not None:
            if self.other_content is None:
                self.other_content = delta.content
            else:
                self.other_content += delta.content
        elif delta.reasoning is not None:
            # Only accumulate when reasoning is actually present: a delta with
            # neither field set would otherwise hit ``str += None`` (TypeError)
            # once some reasoning has already been collected.
            if self.reasoning is None:
                self.reasoning = delta.reasoning
            else:
                self.reasoning += delta.reasoning


def run_reasoning_extraction(
    reasoning_parser: ReasoningParser,
    model_output: list[str],
    request: ChatCompletionRequest | None = None,
    streaming: bool = False,
) -> tuple[str | None, str | None]:
    """Extract ``(reasoning, content)`` from text chunks with ``reasoning_parser``.

    Args:
        reasoning_parser: Parser under test.
        model_output: Model output split into text chunks.
        request: Optional request forwarded to the chosen helper.
        streaming: When true, replay the chunks through the streaming API;
            otherwise join them and use one-shot extraction.

    Returns:
        A ``(reasoning, content)`` pair. In streaming mode, empty accumulated
        content is reported as ``None``.
    """
    if not streaming:
        reasoning, content = run_reasoning_extraction_nonstreaming(
            reasoning_parser, model_output, request
        )
        return reasoning, content

    replayed = run_reasoning_extraction_streaming(
        reasoning_parser,
        model_output,
        request,
    )
    # Collapse "" (and None) content to None.
    return replayed.reasoning, replayed.other_content or None


def run_reasoning_extraction_mistral(
    reasoning_parser: ReasoningParser,
    model_output: list[int],
    request: ChatCompletionRequest | None = None,
    streaming: bool = False,
) -> tuple[str | None, str | None]:
    """Extract ``(reasoning, content)`` from Mistral token ids.

    The parser's tokenizer must be a ``MistralTokenizer``. In non-streaming
    mode the ids are converted to token strings and handed to the one-shot
    helper; in streaming mode they are replayed one id at a time.

    Args:
        reasoning_parser: Parser under test (Mistral tokenizer required).
        model_output: Model output as token ids.
        request: Optional request forwarded to the chosen helper.
        streaming: Selects the streaming replay instead of one-shot extraction.

    Returns:
        A ``(reasoning, content)`` pair. In streaming mode, empty accumulated
        content is reported as ``None``.

    Raises:
        AssertionError: If the parser's tokenizer is not a ``MistralTokenizer``.
    """
    tokenizer = reasoning_parser.model_tokenizer
    assert isinstance(tokenizer, MistralTokenizer), type(tokenizer)

    if not streaming:
        token_strings = tokenizer.convert_ids_to_tokens(model_output)
        reasoning, content = run_reasoning_extraction_nonstreaming(
            reasoning_parser, token_strings, request
        )
        return reasoning, content

    replayed = run_reasoning_extraction_streaming_mistral(
        reasoning_parser,
        model_output,
        request,
    )
    # Collapse "" (and None) content to None.
    return replayed.reasoning, replayed.other_content or None


def run_reasoning_extraction_nonstreaming(
    reasoning_parser: ReasoningParser,
    model_output: list[str],
    request: ChatCompletionRequest | None = None,
) -> tuple[str | None, str | None]:
    """Join all chunks and run the parser's one-shot ``extract_reasoning``.

    Args:
        reasoning_parser: Parser under test.
        model_output: Model output split into text chunks; they are
            concatenated before extraction.
        request: Request passed to the parser. When falsy, a placeholder
            ``ChatCompletionRequest`` with no messages is used.

    Returns:
        Whatever ``reasoning_parser.extract_reasoning`` returns.
    """
    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    full_text = "".join(model_output)
    return reasoning_parser.extract_reasoning(
        model_output=full_text,
        request=request,
    )


def run_reasoning_extraction_streaming(
    reasoning_parser: ReasoningParser,
    model_deltas: list[str],
    request: ChatCompletionRequest | None = None,
) -> StreamingReasoningReconstructor:
    """Replay text chunks through ``reasoning_parser``'s streaming API.

    Each chunk is tokenized with ``reasoning_parser.model_tokenizer`` and the
    resulting tokens are mapped to ids via ``reasoning_parser.vocab``; tokens
    missing from the vocab are dropped. At every step the parser receives the
    cumulative previous/current text and ids plus the new delta, and every
    non-``None`` message it returns is folded into a reconstructor.

    NOTE(review): chunks are tokenized in isolation, so the ids may differ
    from tokenizing the joined output in one go.

    Args:
        reasoning_parser: Parser under test.
        model_deltas: Text chunks, in streaming order.
        request: Accepted for signature parity with the non-streaming helper.
            ``extract_reasoning_streaming`` is not given a request, so this
            argument is unused.

    Returns:
        The reconstructor holding the accumulated reasoning and content.
    """
    # Looked up once rather than twice per token; assumed not to change
    # while the chunks are being replayed.
    vocab = reasoning_parser.vocab

    reconstructor = StreamingReasoningReconstructor()
    previous_text = ""
    previous_tokens: list[int] = []
    for delta in model_deltas:
        token_delta = [
            vocab[token]
            for token in reasoning_parser.model_tokenizer.tokenize(delta)
            if token in vocab
        ]
        current_text = previous_text + delta
        # Build a fresh list instead of extending in place so the lists
        # already handed to the parser are never mutated afterwards.
        current_tokens = previous_tokens + token_delta
        delta_message = reasoning_parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta,
            previous_tokens,
            current_tokens,
            token_delta,
        )
        if delta_message is not None:
            reconstructor.append_delta(delta_message)
        previous_text = current_text
        previous_tokens = current_tokens
    return reconstructor


def run_reasoning_extraction_streaming_mistral(
    reasoning_parser: ReasoningParser,
    model_deltas: list[int],
    request: ChatCompletionRequest | None = None,
) -> StreamingReasoningReconstructor:
    """Replay Mistral token ids, one per step, through the streaming API.

    Each id becomes its own delta: its text is obtained with the tokenizer's
    ``convert_ids_to_tokens``. The parser receives the cumulative
    previous/current text and ids at each step, and every non-``None``
    message it returns is folded into a reconstructor.

    Args:
        reasoning_parser: Parser under test (Mistral tokenizer required).
        model_deltas: Token ids, in streaming order.
        request: Accepted for signature parity with the non-streaming helper.
            ``extract_reasoning_streaming`` is not given a request, so this
            argument is unused.

    Returns:
        The reconstructor holding the accumulated reasoning and content.

    Raises:
        AssertionError: If the parser's tokenizer is not a ``MistralTokenizer``.
    """
    tokenizer = reasoning_parser.model_tokenizer
    assert isinstance(tokenizer, MistralTokenizer), type(tokenizer)

    reconstructor = StreamingReasoningReconstructor()
    previous_text = ""
    previous_tokens: list[int] = []
    for token_id in model_deltas:
        token_delta = [token_id]
        delta = tokenizer.convert_ids_to_tokens([token_id])[0]
        current_text = previous_text + delta
        # Build a fresh list instead of extending in place so the lists
        # already handed to the parser are never mutated afterwards.
        current_tokens = previous_tokens + token_delta
        delta_message = reasoning_parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta,
            previous_tokens,
            current_tokens,
            token_delta,
        )
        if delta_message is not None:
            reconstructor.append_delta(delta_message)
        previous_text = current_text
        previous_tokens = current_tokens
    return reconstructor