# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DetokenizerManager is a process that detokenizes the token ids."""

import dataclasses
import logging
from collections import OrderedDict
from typing import List, Union

import zmq

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__)

@dataclasses.dataclass
class DecodeStatus:
    """Store the status of incremental decoding."""

    vid: int  # Version id; a mismatch means the scheduler re-initialized decoding for this request.
    decoded_text: str  # Text that has already been finalized and emitted.
    decode_ids: List[int]  # All token ids produced for this request so far.
    surr_offset: int  # Start of the surrounding-context window used to stabilize decoding.
    read_offset: int  # End of the tokens already converted to text.


class DetokenizerManager:
    """DetokenizerManager is a process that detokenizes the token ids."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Init inter-process communication
        context = zmq.Context(2)
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name
        )
        self.send_to_tokenizer = get_zmq_socket(
            context, zmq.PUSH, port_args.tokenizer_ipc_name
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.decode_status = LimitedCapacityDict()
    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
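        """Trim the matched stop string / stop token from the output.

        `output` is either the decoded text or the raw token ids; when
        `no_stop_trim` is set, it is returned unchanged.
        """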
        if no_stop_trim:
            return output

        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
            pos = output.find(finished_reason.matched)
            return output[:pos] if pos != -1 else output
        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
            output, list
        ):
            assert len(output) > 0
            return output[:-1]
        return output

    def event_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            if isinstance(recv_obj, BatchEmbeddingOut):
                # If it is embedding model, no detokenization is needed.
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue
            else:
                assert isinstance(recv_obj, BatchTokenIDOut)
            bs = len(recv_obj.rids)

            # Initialize decode status
            read_ids, surr_ids = [], []
            for i in range(bs):
                rid = recv_obj.rids[i]
                vid = recv_obj.vids[i]
                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
                    s = DecodeStatus(
                        vid=vid,
                        decoded_text=recv_obj.decoded_texts[i],
                        decode_ids=recv_obj.decode_ids[i],
                        surr_offset=0,
                        read_offset=recv_obj.read_offsets[i],
                    )
                    self.decode_status[rid] = s
                else:
                    s = self.decode_status[rid]
                    s.decode_ids = recv_obj.decode_ids[i]

                read_ids.append(
                    self.trim_eos(
                        s.decode_ids[s.surr_offset :],
                        recv_obj.finished_reason[i],
                        recv_obj.no_stop_trim[i],
                    )
                )
                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
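            # Detokenization is not prefix-stable: a token's surface form can
            # depend on the tokens before it. So decode the full read window
            # (decode_ids[surr_offset:]) together with its surrounding context
            # (decode_ids[surr_offset:read_offset]) and later take the suffix
            # of the former past the latter as the newly generated text.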
            surr_texts = self.tokenizer.batch_decode(
                surr_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
            read_texts = self.tokenizer.batch_decode(
                read_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )

            # Incremental decoding
            output_strs = []
            for i in range(bs):
                s = self.decode_status[recv_obj.rids[i]]
                new_text = read_texts[i][len(surr_texts[i]) :]
                if recv_obj.finished_reason[i] is None:
                    # Streaming chunk: update the decode status
                    if len(new_text) > 0 and not new_text.endswith("�"):
                        # The chunk decodes cleanly: commit it and advance the
                        # surrounding/read window.
                        s.decoded_text = s.decoded_text + new_text
                        s.surr_offset = s.read_offset
                        s.read_offset = len(s.decode_ids)
                        new_text = ""
                    else:
                        # A trailing "�" (U+FFFD) means the last token ends in an
                        # incomplete multi-byte sequence; emit only the printable
                        # prefix and retry once more tokens arrive.
                        new_text = find_printable_text(new_text)

                output_strs.append(
                    self.trim_eos(
                        s.decoded_text + new_text,
                        recv_obj.finished_reason[i],
                        recv_obj.no_stop_trim[i],
                    )
                )
            self.send_to_tokenizer.send_pyobj(
                BatchStrOut(
                    rids=recv_obj.rids,
                    output_strs=output_strs,
                    meta_info=recv_obj.meta_info,
                    finished_reason=recv_obj.finished_reason,
                )
            )


class LimitedCapacityDict(OrderedDict):
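    """An OrderedDict bounded at `capacity` entries.

    Inserting beyond the bound evicts the least-recently-inserted key
    (FIFO eviction; lookups do not reorder entries). Illustrative use:

        d = LimitedCapacityDict(capacity=2)
        d["a"] = 1; d["b"] = 2; d["c"] = 3   # "a" is evicted
        assert list(d) == ["b", "c"]
    """
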
    def __init__(self, capacity=1 << 15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if len(self) >= self.capacity:
            # Remove the oldest element (first item in the dict)
            self.popitem(last=False)
        # Set the new item
        super().__setitem__(key, value)


def run_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
):
    configure_logger(server_args)

    try:
        manager = DetokenizerManager(server_args, port_args)
        manager.event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
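

# A sketch of how this entry point is typically wired up (illustrative; the
# actual launch lives in the server startup code): the detokenizer runs in
# its own process, e.g.
#
#   import multiprocessing as mp
#   proc = mp.Process(target=run_detokenizer_process, args=(server_args, port_args))
#   proc.start()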