"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""DetokenizerManager is a process that detokenizes the token ids."""

import asyncio
import dataclasses
from typing import List, Union

import uvloop
import zmq
import zmq.asyncio

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback

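# Use uvloop as the asyncio event loop implementation.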
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


@dataclasses.dataclass
class DecodeStatus:
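    """Per-request state for incremental, streaming detokenization.

    decode_ids[surr_offset:read_offset] holds already-emitted context that is
    re-decoded together with the new tokens so token/byte boundaries stay
    consistent; vid is a version id, and a mismatch rebuilds this state.
    """
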
    vid: int
    decoded_text: str
    decode_ids: List[int]
    surr_offset: int
    read_offset: int


class DetokenizerManager:
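    """Pulls BatchTokenIDOut/BatchEmbeddingOut objects from the router,
    detokenizes token ids incrementally, and pushes the results back to the
    tokenizer manager."""
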
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        context = zmq.asyncio.Context(2)
        self.recv_from_router = context.socket(zmq.PULL)
        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        if server_args.skip_tokenizer_init:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

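        # Incremental decode state of each request, keyed by rid.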
        self.decode_status = {}

    async def handle_loop(self):
        while True:
            recv_obj: Union[BatchEmbeddingOut, BatchTokenIDOut] = (
                await self.recv_from_router.recv_pyobj()
            )

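            # Embedding requests carry no token ids; forward the embeddings
            # straight through to the tokenizer manager.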
            if isinstance(recv_obj, BatchEmbeddingOut):
                self.send_to_tokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        rids=recv_obj.rids,
                        embeddings=recv_obj.embeddings,
                        meta_info=recv_obj.meta_info,
                        finished_reason=recv_obj.finished_reason,
                    )
                )
                continue

            assert isinstance(recv_obj, BatchTokenIDOut)
            bs = len(recv_obj.rids)

            if self.tokenizer is None:
                # No tokenizer was initialized; forward the raw token ids.
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue

            # Initialize decode status
            read_ids, surr_ids = [], []
            for i in range(bs):
                rid = recv_obj.rids[i]
                vid = recv_obj.vids[i]
                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
                    s = DecodeStatus(
                        vid=vid,
                        decoded_text=recv_obj.decoded_texts[i],
                        decode_ids=recv_obj.decode_ids[i],
                        surr_offset=0,
                        read_offset=recv_obj.read_offsets[i],
                    )
                    self.decode_status[rid] = s
                else:
                    s = self.decode_status[rid]
                    s.decode_ids = recv_obj.decode_ids[i]

                read_ids.append(s.decode_ids[s.surr_offset :])
                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

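            # Decode the context window (surr_ids) and the context+new-tokens
            # window (read_ids); the newly generated text is the suffix of
            # each read_texts entry beyond the corresponding surr_texts entry.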
            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
            surr_texts = self.tokenizer.batch_decode(
                surr_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
            read_texts = self.tokenizer.batch_decode(
                read_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )

            # Trim stop str
            # TODO(lmzheng): handle the case where multiple stop strs are hit
            output_strs = []
            for i in range(bs):
                s = self.decode_status[recv_obj.rids[i]]
                new_text = read_texts[i][len(surr_texts[i]) :]
                if recv_obj.finished_reason[i] is None:
                    # Streaming chunk: update the decode status
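                    # A trailing "�" (U+FFFD) means the last token ends in an
                    # incomplete UTF-8 sequence, so only the printable prefix
                    # is surfaced until more tokens arrive.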
                    if len(new_text) > 0 and not new_text.endswith("�"):
                        s.decoded_text = s.decoded_text + new_text
                        s.surr_offset = s.read_offset
                        s.read_offset = len(s.decode_ids)
                        new_text = ""
                    else:
                        new_text = find_printable_text(new_text)

                output_strs.append(s.decoded_text + new_text)

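                # If a stop string was matched, truncate the output at its
                # first occurrence.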
                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                    if pos != -1:
                        output_strs[i] = output_strs[i][:pos]

            self.send_to_tokenizer.send_pyobj(
                BatchStrOut(
                    rids=recv_obj.rids,
                    output_strs=output_strs,
                    meta_info=recv_obj.meta_info,
                    finished_reason=recv_obj.finished_reason,
                )
            )


def start_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
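    """Subprocess entry point: reports "init ok" or a traceback via pipe_writer."""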
    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise
    pipe_writer.send("init ok")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(manager.handle_loop())
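

# A minimal launch sketch (assumes `server_args` and `port_args` are
# already-constructed ServerArgs / PortArgs; the real wiring lives in
# sglang's server startup code, not in this module):
#
#     import multiprocessing as mp
#
#     reader, writer = mp.Pipe(duplex=False)
#     proc = mp.Process(
#         target=start_detokenizer_process,
#         args=(server_args, port_args, writer),
#     )
#     proc.start()
#     assert reader.recv() == "init ok"  # a traceback string means init failed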