detokenizer_manager.py 10.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""DetokenizerManager is a process that detokenizes the token ids."""
15

16
import dataclasses
17
import logging
18
import os
19
import signal
20
from collections import OrderedDict
Lianmin Zheng's avatar
Lianmin Zheng committed
21
from typing import Dict, List, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
22

23
import psutil
24
import setproctitle
Lianmin Zheng's avatar
Lianmin Zheng committed
25
import zmq
Liangsheng Yin's avatar
Liangsheng Yin committed
26

Lianmin Zheng's avatar
Lianmin Zheng committed
27
from sglang.srt.hf_transformers_utils import get_tokenizer
28
29
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
30
    BatchMultimodalDecodeReq,
31
    BatchMultimodalOut,
32
33
    BatchStrOut,
    BatchTokenIDOut,
34
    FreezeGCReq,
35
)
Lianmin Zheng's avatar
Lianmin Zheng committed
36
from sglang.srt.server_args import PortArgs, ServerArgs
37
38
from sglang.srt.utils import (
    configure_logger,
39
    freeze_gc,
40
41
42
    get_zmq_socket,
    kill_itself_when_parent_died,
)
43
44
45
46
47
from sglang.utils import (
    TypeBasedDispatcher,
    find_printable_text,
    get_exception_traceback,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
48

49
50
logger = logging.getLogger(__name__)

51
52
53
54
55
56
# Maximum number of request states that detokenizer can hold. When exceeded,
# oldest request states will be evicted. Default: 65536 (1<<16).
# For more details, see: https://github.com/sgl-project/sglang/issues/2812
# Use power of 2 values for better memory allocation.
DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))

Lianmin Zheng's avatar
Lianmin Zheng committed
57

58
59
@dataclasses.dataclass
class DecodeStatus:
    """Store the status of incremental decoding."""

    # Text already decoded and confirmed printable for this request.
    decoded_text: str
    # All token ids accumulated for this request (extended on each chunk).
    decode_ids: List[int]
    # Start of the surrogate window: tokens before this offset are already
    # part of decoded_text and are re-decoded only as context.
    surr_offset: int
    # End of the confirmed window: tokens in [surr_offset, read_offset) were
    # decoded on the previous step; tokens at/after read_offset are new.
    read_offset: int

    # Offset that's sent to tokenizer for incremental update.
    sent_offset: int = 0
68
69


Lianmin Zheng's avatar
Lianmin Zheng committed
70
class DetokenizerManager:
    """DetokenizerManager is a process that detokenizes the token ids.

    It pulls batch outputs from the scheduler over ZMQ, performs incremental
    detokenization per request, and pushes string outputs to the tokenizer
    manager.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Init inter-process communication
        context = zmq.Context(2)
        # PULL socket bound (bind=True) to receive batches from the scheduler.
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name, True
        )
        # PUSH socket connected (bind=False) to forward results to the tokenizer manager.
        self.send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
        )

        # When the server skips tokenizer init, no detokenization is possible;
        # self.tokenizer stays None (handlers that need it must not be reached).
        if server_args.skip_tokenizer_init:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

        # Bounded per-request decode state; oldest entries are evicted when
        # more than DETOKENIZER_MAX_STATES requests are in flight.
        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
        self.is_dummy = server_args.load_format == "dummy"

        # Route each incoming message type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (BatchEmbeddingOut, self.handle_batch_embedding_out),
                (BatchTokenIDOut, self.handle_batch_token_id_out),
                (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
                (FreezeGCReq, self.handle_freeze_gc_req),
            ]
        )

    def event_loop(self):
        """The event loop that handles requests"""
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            # Handlers may return None (e.g. FreezeGCReq) to suppress a reply.
            if output is not None:
                self.send_to_tokenizer.send_pyobj(output)

    def trim_matched_stop(
        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
    ):
        """Trim the matched stop string or stop token from `output`.

        Returns `output` unchanged when trimming is disabled, the request is
        not finished, or no stop condition was matched.
        """
        if no_stop_trim or not finished_reason:
            return output

        # `matched` is a stop string (str) or a stop token id (int).
        matched = finished_reason.get("matched", None)
        if not matched:
            return output

        # TODO(lmzheng): handle the case where multiple stop strs are hit

        # Trim stop str.
        if isinstance(matched, str) and isinstance(output, str):
            pos = output.find(matched)
            return output[:pos] if pos != -1 else output

        # Trim stop token.
        if isinstance(matched, int) and isinstance(output, list):
            assert len(output) > 0
            return output[:-1]
        return output

    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
        # If it is embedding model, no detokenization is needed.
        return recv_obj

    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
        """Incrementally detokenize a batch of token-id outputs into strings."""
        bs = len(recv_obj.rids)

        # Initialize decode status
        read_ids, surr_ids = [], []
        for i in range(bs):
            rid = recv_obj.rids[i]
            if rid not in self.decode_status:
                # First chunk for this request: seed the decode window.
                s = DecodeStatus(
                    decoded_text=recv_obj.decoded_texts[i],
                    decode_ids=recv_obj.decode_ids[i],
                    surr_offset=0,
                    read_offset=recv_obj.read_offsets[i],
                )
                self.decode_status[rid] = s
            else:
                # Subsequent chunk: append the new token ids.
                s = self.decode_status[rid]
                s.decode_ids.extend(recv_obj.decode_ids[i])

            # read_ids: surrogate context + new tokens (stop trimmed if finished).
            read_ids.append(
                self.trim_matched_stop(
                    s.decode_ids[s.surr_offset :],
                    recv_obj.finished_reasons[i],
                    recv_obj.no_stop_trim[i],
                )
            )
            # surr_ids: the surrogate context alone, decoded separately so the
            # new text can be computed as a suffix difference below.
            surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

        # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
        surr_texts = self.tokenizer.batch_decode(
            surr_ids,
            skip_special_tokens=recv_obj.skip_special_tokens[0],
            spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
        )
        read_texts = self.tokenizer.batch_decode(
            read_ids,
            skip_special_tokens=recv_obj.skip_special_tokens[0],
            spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
        )

        # Incremental decoding
        output_strs = []
        for i in range(bs):
            try:
                s = self.decode_status[recv_obj.rids[i]]
            except KeyError:
                raise RuntimeError(
                    f"Decode status not found for request {recv_obj.rids[i]}. "
                    "It may be due to the request being evicted from the decode status due to memory pressure. "
                    "Please increase the maximum number of requests by setting "
                    "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
                    f"The current value is {DETOKENIZER_MAX_STATES}. "
                    "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
                )
            # New text is whatever the full window decoded beyond the context.
            new_text = read_texts[i][len(surr_texts[i]) :]
            if recv_obj.finished_reasons[i] is None:
                # Streaming chunk: update the decode status
                # A trailing U+FFFD suggests an incomplete multi-byte sequence,
                # so only commit text when it does not end with one.
                if len(new_text) > 0 and not new_text.endswith("�"):
                    s.decoded_text = s.decoded_text + new_text
                    s.surr_offset = s.read_offset
                    s.read_offset = len(s.decode_ids)
                    new_text = ""
                else:
                    new_text = find_printable_text(new_text)

            output_str = self.trim_matched_stop(
                s.decoded_text + new_text,
                recv_obj.finished_reasons[i],
                recv_obj.no_stop_trim[i],
            )

            # Incrementally send text.
            incremental_output = output_str[s.sent_offset :]
            s.sent_offset = len(output_str)
            output_strs.append(incremental_output)

        return BatchStrOut(
            rids=recv_obj.rids,
            finished_reasons=recv_obj.finished_reasons,
            output_strs=output_strs,
            output_ids=recv_obj.decode_ids,
            prompt_tokens=recv_obj.prompt_tokens,
            completion_tokens=recv_obj.completion_tokens,
            cached_tokens=recv_obj.cached_tokens,
            spec_verify_ct=recv_obj.spec_verify_ct,
            input_token_logprobs_val=recv_obj.input_token_logprobs_val,
            input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
            output_token_logprobs_val=recv_obj.output_token_logprobs_val,
            output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
            input_top_logprobs_val=recv_obj.input_top_logprobs_val,
            input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
            output_top_logprobs_val=recv_obj.output_top_logprobs_val,
            output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
            input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
            output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
            output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
            output_hidden_states=recv_obj.output_hidden_states,
        )

    def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
        """Delegate multimodal decoding to the tokenizer and wrap the result."""
        outputs = self.tokenizer.detokenize(recv_obj)
        return BatchMultimodalOut(
            rids=recv_obj.rids,
            finished_reasons=recv_obj.finished_reasons,
            outputs=outputs,
            prompt_tokens=recv_obj.prompt_tokens,
            completion_tokens=recv_obj.completion_tokens,
            cached_tokens=recv_obj.cached_tokens,
        )

    def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
        """Freeze the GC in this process; returns None so no reply is sent."""
        freeze_gc("Detokenizer Manager")
        return None

Lianmin Zheng's avatar
Lianmin Zheng committed
258

259
class LimitedCapacityDict(OrderedDict):
    """An OrderedDict bounded to `capacity` entries with FIFO eviction.

    When inserting a *new* key while the dict is full, the oldest entry
    (first inserted) is evicted. Overwriting an existing key never evicts.
    """

    def __init__(self, capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        # Only evict when a genuinely new key would exceed capacity.
        # (Previously this evicted even on overwrite of an existing key,
        # which could silently drop an unrelated oldest entry.)
        if key not in self and len(self) >= self.capacity:
            # Remove the oldest element (first item in the dict)
            self.popitem(last=False)
        # Set the new item
        super().__setitem__(key, value)


def run_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
):
    """Entry point of the detokenizer subprocess.

    Sets up process-level housekeeping, then runs the DetokenizerManager
    event loop forever. On any exception, logs the traceback and signals
    the parent process with SIGQUIT so the whole server shuts down.
    """
    # Exit automatically if the parent process dies, to avoid orphans.
    kill_itself_when_parent_died()
    setproctitle.setproctitle("sglang::detokenizer")
    configure_logger(server_args)
    # Resolve the parent before entering the loop so we can signal it on failure.
    parent_process = psutil.Process().parent()

    try:
        manager = DetokenizerManager(server_args, port_args)
        manager.event_loop()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"DetokenizerManager hit an exception: {traceback}")
        # Tell the parent to shut everything down rather than running degraded.
        parent_process.send_signal(signal.SIGQUIT)