detokenizer_manager.py 4.89 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
"""DetokenizerManager is a process that detokenizes the token ids."""
2

Lianmin Zheng's avatar
Lianmin Zheng committed
3
import asyncio
4
import dataclasses
5
import inspect
6
from typing import List
Lianmin Zheng's avatar
Lianmin Zheng committed
7
8
9
10

import uvloop
import zmq
import zmq.asyncio
Liangsheng Yin's avatar
Liangsheng Yin committed
11

Lianmin Zheng's avatar
Lianmin Zheng committed
12
from sglang.srt.hf_transformers_utils import get_tokenizer
13
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
Lianmin Zheng's avatar
Lianmin Zheng committed
14
15
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
16
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
Lianmin Zheng's avatar
Lianmin Zheng committed
17
18
19
20

# Install uvloop's event-loop policy at import time so event loops obtained
# through asyncio in this process are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


21
22
23
24
25
26
27
28
@dataclasses.dataclass
class DecodeStatus:
    """Incremental detokenization state tracked for a single request.

    The detokenizer re-decodes the window ``decode_ids[surr_offset:]`` on every
    batch and strips the text of ``decode_ids[surr_offset:read_offset]`` from it
    to isolate the text contributed by ids that have not been committed yet.
    """

    # Text committed so far for this request.
    decoded_text: str
    # The request's token ids, replaced by the latest list on every batch.
    decode_ids: List[int]
    # Start index of the id window that is re-decoded each step.
    surr_offset: int
    # Index into decode_ids up to which text has been committed to decoded_text.
    read_offset: int


Lianmin Zheng's avatar
Lianmin Zheng committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class DetokenizerManager:
    """Convert batches of generated token ids into text.

    Pulls ``BatchTokenIDOut`` messages from a ZMQ PULL socket bound to
    ``port_args.detokenizer_port``, incrementally detokenizes every request in
    the batch, and pushes a ``BatchStrOut`` to ``port_args.tokenizer_port``.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        context = zmq.asyncio.Context(2)
        self.recv_from_router = context.socket(zmq.PULL)
        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )

        # rid -> DecodeStatus. An entry is created on the first batch that
        # mentions a request and removed once that request reports a finish
        # reason, so the dict only holds requests that are still in flight.
        self.decode_status = {}

    def _batch_decode(
        self,
        id_lists,
        skip_special_tokens,
        spaces_between_special_tokens,
    ):
        """Decode ``id_lists[i]`` using the special-token flags of request ``i``.

        Args:
            id_lists: One list of token ids per request.
            skip_special_tokens: Per-request flags, indexed like ``id_lists``.
            spaces_between_special_tokens: Per-request flags, indexed like
                ``id_lists``.

        Returns:
            A list of decoded strings aligned with ``id_lists``.

        Requests sharing the same pair of flags are decoded together in one
        ``tokenizer.batch_decode`` call, so a batch with uniform flags still
        costs a single call.
        """
        # Index the flag lists directly (rather than zip) so a length mismatch
        # raises instead of silently leaving some outputs undecoded.
        groups = {}
        for i in range(len(id_lists)):
            flags = (skip_special_tokens[i], spaces_between_special_tokens[i])
            groups.setdefault(flags, []).append(i)

        texts = [""] * len(id_lists)
        for (skip, spaces), indices in groups.items():
            decoded = self.tokenizer.batch_decode(
                [id_lists[i] for i in indices],
                skip_special_tokens=skip,
                spaces_between_special_tokens=spaces,
            )
            for i, text in zip(indices, decoded):
                texts[i] = text
        return texts

    async def handle_loop(self):
        """Detokenize incoming batches forever.

        For every ``BatchTokenIDOut`` received, build the text generated so far
        for each request and push it to the tokenizer as a ``BatchStrOut``.
        """
        while True:
            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
            assert isinstance(recv_obj, BatchTokenIDOut)
            bs = len(recv_obj.rids)

            # FIXME: incremental detokenize is not compatible with jump forward
            # Look up (or create) each request's decode status and build the
            # two id windows decoded below.
            statuses, read_ids, surr_ids = [], [], []
            for i in range(bs):
                rid = recv_obj.rids[i]
                s = self.decode_status.get(rid)
                if s is None:
                    s = DecodeStatus(
                        decoded_text=recv_obj.decoded_texts[i],
                        decode_ids=recv_obj.decode_ids[i],
                        surr_offset=0,
                        read_offset=recv_obj.read_offsets[i],
                    )
                    self.decode_status[rid] = s
                else:
                    s.decode_ids = recv_obj.decode_ids[i]

                statuses.append(s)
                read_ids.append(s.decode_ids[s.surr_offset :])
                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

            # Honor each request's own special-token settings instead of
            # applying the first request's flags to the whole batch.
            surr_texts = self._batch_decode(
                surr_ids,
                recv_obj.skip_special_tokens,
                recv_obj.spaces_between_special_tokens,
            )
            read_texts = self._batch_decode(
                read_ids,
                recv_obj.skip_special_tokens,
                recv_obj.spaces_between_special_tokens,
            )

            output_strs = []
            for i in range(bs):
                s = statuses[i]
                finished_reason = recv_obj.finished_reason[i]
                # Text contributed by ids past read_offset: the decode of the
                # whole window minus the decode of its already-read prefix.
                new_text = read_texts[i][len(surr_texts[i]) :]
                if finished_reason is None:
                    # Streaming chunk: update the decode status
                    if len(new_text) > 0 and not new_text.endswith("�"):
                        s.decoded_text = s.decoded_text + new_text
                        s.surr_offset = s.read_offset
                        s.read_offset = len(s.decode_ids)
                        new_text = ""
                    else:
                        # A trailing U+FFFD presumably means the newest ids end
                        # mid-character: don't commit yet, and only emit what
                        # find_printable_text keeps.
                        new_text = find_printable_text(new_text)

                output_str = s.decoded_text + new_text

                # Trim stop str
                # TODO(lmzheng): handle the case where multiple stop strs are hit
                if isinstance(finished_reason, FINISH_MATCHED_STR):
                    pos = output_str.find(finished_reason.matched)
                    if pos != -1:
                        output_str = output_str[:pos]
                output_strs.append(output_str)

                if finished_reason is not None:
                    # The request is done: drop its state so decode_status does
                    # not keep growing with every request ever served.
                    self.decode_status.pop(recv_obj.rids[i], None)

            self.send_to_tokenizer.send_pyobj(
                BatchStrOut(
                    rids=recv_obj.rids,
                    output_strs=output_strs,
                    meta_info=recv_obj.meta_info,
                    finished_reason=recv_obj.finished_reason,
                )
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
119
120
121
122
123
124
125


def start_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Entry point of the detokenizer subprocess.

    Registers this function's name with ``graceful_registry`` and builds a
    ``DetokenizerManager``. It then reports the outcome to the parent through
    ``pipe_writer``: ``"init ok"`` on success, or the formatted traceback
    before re-raising on failure. Finally it runs the manager's receive loop
    on this process's event loop.
    """
    own_name = inspect.currentframe().f_code.co_name
    graceful_registry(own_name)

    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception:
        # Let the parent see why start-up failed, then propagate the error.
        pipe_writer.send(get_exception_traceback())
        raise
    else:
        pipe_writer.send("init ok")

    asyncio.get_event_loop().run_until_complete(manager.handle_loop())