tokenizer_manager.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import dataclasses
import json
import logging
import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_zmq_socket,
    is_generation_model,
    is_multimodal_model,
    kill_child_process,
)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )
        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # Others
        self.gracefully_exit = False

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
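        """Entry point for generation, embedding, and reward requests; yields output dicts."""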
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
            )

        obj.post_init()
        is_single = obj.is_single
        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _send_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
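        """Tokenize one request (or one item of a batch), send it to the scheduler,
        and return (rid, input_ids)."""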
        if not is_cache_for_prefill:  # The normal case with a single prompt
            if index is None:
                rid = obj.rid
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = obj.text if obj.text is not None else None
                    input_ids = obj.input_ids

                sampling_params = self._get_sampling_params(obj.sampling_params)
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data, input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob
                    logprob_start_len = obj.logprob_start_len
                    top_logprobs_num = obj.top_logprobs_num
            else:
                rid = obj.rid[index]
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv[index]
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text[input_id_index]
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = (
                        obj.text[input_id_index] if obj.text is not None else None
                    )
                    input_ids = obj.input_ids[input_id_index]

                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob[index]
                    logprob_start_len = obj.logprob_start_len[index]
                    top_logprobs_num = obj.top_logprobs_num[index]

            self._validate_input_length(input_ids)

        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[input_id_index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send to the scheduler
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[input_id_index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

        self.send_to_scheduler.send_pyobj(tokenized_obj)
        return rid, input_ids

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
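        """Send one request, then stream its results (or just yield its input_ids for a
        cache-for-prefill request)."""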
        rid, input_ids = await self._send_single_request(
            obj,
            index,
            input_id_index=input_id_index,
            is_cache_for_prefill=is_cache_for_prefill,
        )

        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state

        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
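        """Send every request in a batch, then stream or collect their results."""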
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj,
                        request,
                        index=i,
                        input_id_index=i,
                        is_cache_for_prefill=True,
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None:
                    obj.input_ids = input_id_result
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # With parallel sampling, account for the prefill requests sent above:
                    # index = j + i * (parallel_sample_num - 1) + batch_size - 1
                    index += batch_size - 1 - i

                rid, _ = await self._send_single_request(
                    obj, index, input_id_index=i, is_cache_for_prefill=False
                )

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        # Fetch results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
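        """Wait on the request's event and yield its outputs until it finishes."""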
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
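        """Wait until the cache-for-prefill request finishes, then drop its state."""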
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
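        """Ask the scheduler(s) for the memory pool size; returns a list when dp_size > 1."""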
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()
        if self.server_args.dp_size == 1:
            self.send_to_scheduler.send_pyobj(req)
            self.mem_pool_size = asyncio.Future()
            res = await self.mem_pool_size
            return res.size
        else:  # self.server_args.dp_size > 1
            self.send_to_scheduler.send_pyobj(req)
            self.mem_pool_size = asyncio.Future()
            self.mem_pool_size_tmp = []
            res = await self.mem_pool_size
            return [r.size for r in res]

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
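        """Update the model weights and return (success, message)."""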
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            if self.server_args.dp_size == 1:
                async with self.model_update_lock:
                    # wait for the previous generation requests to finish
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    self.send_to_scheduler.send_pyobj(obj)
                    self.model_update_result = asyncio.Future()
                    result = await self.model_update_result
                    if result.success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                return result.success, result.message

            else:  # self.server_args.dp_size > 1
                # There will be dp_size responses from the detokenizer
                async with self.model_update_lock:
                    # wait for the previous generation requests to finish
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    self.send_to_scheduler.send_pyobj(obj)
                    self.model_update_result = asyncio.Future()
                    self.model_update_tmp = []
                    result = await self.model_update_result

                    all_success = all(r.success for r in result)
                    if all_success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    all_message = " | ".join(r.message for r in result)
                return all_success, all_message

        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
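        """Create the background task that receives results and the SIGTERM watchdog (once)."""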
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
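        """After SIGTERM, wait for in-flight requests to drain, then shut down."""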
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_child_process(include_self=True)
        sys.exit(-1)

    async def handle_loop(self):
        """The event loop that handles requests."""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future once all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                if self.server_args.dp_size == 1:
                    self.mem_pool_size.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.mem_pool_size_tmp.append(recv_obj)
                    # Set the future once all results are received
                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
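        """Detokenize the logprob entries in `ret["meta_info"]` when text is requested."""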
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


class SignalHandler:
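    """Mark the TokenizerManager for graceful exit when SIGTERM is received."""
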
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True