detokenizer_manager.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""DetokenizerManager is a process that detokenizes the token ids."""

import dataclasses
import logging
from collections import OrderedDict
from typing import List

import zmq

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class DecodeStatus:
    """Store the status of incremental decoding."""

    vid: int
    decoded_text: str
    decode_ids: List[int]
    surr_offset: int
    read_offset: int
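    # A hedged illustration with made-up ids: if decode_ids = [11, 22, 33, 44],
    # surr_offset = 1, and read_offset = 3, the text of decode_ids[1:3] already
    # forms the tail of decoded_text; the event loop re-decodes decode_ids[1:]
    # and strips the decode_ids[1:3] prefix to get the newly generated text.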


class DetokenizerManager:
    """DetokenizerManager is a process that detokenizes the token ids."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Init inter-process communication
        context = zmq.Context(2)
        self.recv_from_scheduler = context.socket(zmq.PULL)
        self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")
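        # Note: this manager sits in the middle of a ZMQ IPC pipeline. It PULLs
        # batched outputs from the scheduler and PUSHes the detokenized results
        # back to the tokenizer side of the server.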

        if server_args.skip_tokenizer_init:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

        self.decode_status = LimitedCapacityDict()

    def event_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()

            if isinstance(recv_obj, BatchEmbeddingOut):
                # If it is an embedding model, no detokenization is needed.
                self.send_to_tokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        rids=recv_obj.rids,
                        embeddings=recv_obj.embeddings,
                        meta_info=recv_obj.meta_info,
                        finished_reason=recv_obj.finished_reason,
                    )
                )
                continue
            elif isinstance(recv_obj, UpdateWeightReqOutput):
                # If it is a weight update request, no detokenization is needed.
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue
            elif self.tokenizer is None:
                # If the tokenizer is skipped, no detokenization is needed.
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue

            assert isinstance(recv_obj, BatchTokenIDOut)
            bs = len(recv_obj.rids)

            # Initialize decode status
            read_ids, surr_ids = [], []
            for i in range(bs):
                rid = recv_obj.rids[i]
                vid = recv_obj.vids[i]
                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
                    s = DecodeStatus(
                        vid=vid,
                        decoded_text=recv_obj.decoded_texts[i],
                        decode_ids=recv_obj.decode_ids[i],
                        surr_offset=0,
                        read_offset=recv_obj.read_offsets[i],
                    )
                    self.decode_status[rid] = s
                else:
                    s = self.decode_status[rid]
                    s.decode_ids = recv_obj.decode_ids[i]

                read_ids.append(s.decode_ids[s.surr_offset :])
                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
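            # Note (hedged): read_ids re-decodes only a short suffix of each
            # sequence, while surr_ids supplies the same leading context tokens;
            # the string difference computed below is the newly generated text.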

            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
            surr_texts = self.tokenizer.batch_decode(
                surr_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
            read_texts = self.tokenizer.batch_decode(
                read_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )

            # Incremental decoding
            output_strs = []
            for i in range(bs):
                s = self.decode_status[recv_obj.rids[i]]
                new_text = read_texts[i][len(surr_texts[i]) :]
                if recv_obj.finished_reason[i] is None:
                    # Streaming chunk: update the decode status.
                    # A trailing "�" (U+FFFD) means the chunk ends in an
                    # incomplete multi-byte character, so that text is held back.
                    if len(new_text) > 0 and not new_text.endswith("�"):
                        s.decoded_text = s.decoded_text + new_text
                        s.surr_offset = s.read_offset
                        s.read_offset = len(s.decode_ids)
                        new_text = ""
                    else:
                        new_text = find_printable_text(new_text)

                output_strs.append(s.decoded_text + new_text)

                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                    if pos != -1:
                        output_strs[i] = output_strs[i][:pos]

            self.send_to_tokenizer.send_pyobj(
                BatchStrOut(
                    rids=recv_obj.rids,
                    output_strs=output_strs,
                    meta_info=recv_obj.meta_info,
                    finished_reason=recv_obj.finished_reason,
                )
            )


class LimitedCapacityDict(OrderedDict):
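    """An OrderedDict that evicts its oldest entry once `capacity` is reached.

    Sketch of the eviction behavior:
        >>> d = LimitedCapacityDict(capacity=2)
        >>> d["a"], d["b"], d["c"] = 1, 2, 3
        >>> list(d)  # "a", the oldest key, was evicted
        ['b', 'c']
    """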
    def __init__(self, capacity=1 << 15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if len(self) >= self.capacity:
            # Remove the oldest element (first item in the dict)
            self.popitem(last=False)
        # Set the new item
        super().__setitem__(key, value)


def run_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
):
    configure_logger(server_args)

    try:
        manager = DetokenizerManager(server_args, port_args)
        manager.event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
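

# A hedged usage sketch, not part of this module: the SGLang server is expected
# to spawn this loop in a separate process, roughly like
#
#   import multiprocessing as mp
#   proc = mp.Process(target=run_detokenizer_process, args=(server_args, port_args))
#   proc.start()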