# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
import time
import uuid
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0
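
    # Lifecycle note: a ReqState is created in TokenizerManager._send_one_request,
    # filled in and finished by handle_loop(), and consumed in _wait_one_response,
    # which waits on `event` and drains `out_list`.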


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # Lock and result future used to coordinate model weight updates.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.server_args.log_requests:
            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response
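
    # Illustrative usage (a sketch, not part of this module): the HTTP layer is
    # expected to drive this async generator roughly like
    #     async for out in tokenizer_manager.generate_request(obj, request):
    #         ...  # stream or collect `out`
    # For a non-streaming single request the generator yields exactly one item.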

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
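
    # Note: the actual responses are produced by handle_loop(), which appends
    # decoded outputs to state.out_list and sets state.event; this coroutine
    # only waits on the event and drains the list.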

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
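            # (a throwaway request with max_new_tokens=0 is sent first so the
            # shared prompt is prefilled and its prefix can be reused by the
            # expanded samples below)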
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"

        # Hold the lock so that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # Hold the lock so that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut,
                BatchEmbeddingOut,
                BatchTokenIDOut,
                UpdateWeightFromDiskReqOutput,
                UpdateWeightsFromDistributedReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
                for i, rid in enumerate(recv_obj.rids):
                    state = self.rid_to_state.get(rid, None)
                    if state is None:
                        continue

                    meta_info = {
                        "id": rid,
                        "finish_reason": recv_obj.finished_reasons[i],
                        "prompt_tokens": recv_obj.prompt_tokens[i],
                    }

                    if getattr(state.obj, "return_logprob", False):
                        self.convert_logprob_style(
                            meta_info,
                            state.obj.top_logprobs_num,
                            state.obj.return_text_in_logprobs,
                            recv_obj,
                            i,
                        )

                    if not isinstance(recv_obj, BatchEmbeddingOut):
                        meta_info.update(
                            {
                                "completion_tokens": recv_obj.completion_tokens[i],
                                "cached_tokens": recv_obj.cached_tokens[i],
                            }
                        )

                    if isinstance(recv_obj, BatchStrOut):
                        out_dict = {
                            "text": recv_obj.output_strs[i],
                            "meta_info": meta_info,
                        }
                    elif isinstance(recv_obj, BatchTokenIDOut):
                        out_dict = {
                            "token_ids": recv_obj.output_ids[i],
                            "meta_info": meta_info,
                        }
                    else:
                        assert isinstance(recv_obj, BatchEmbeddingOut)
                        out_dict = {
                            "embedding": recv_obj.embeddings[i],
                            "meta_info": meta_info,
                        }
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished_reasons[i] is not None
                    state.event.set()

                    if self.enable_metrics:
                        completion_tokens = (
                            recv_obj.completion_tokens[i]
                            if getattr(recv_obj, "completion_tokens", None)
                            else 0
                        )

                        if state.first_token_time is None:
                            state.first_token_time = time.time()
                            self.metrics_collector.observe_time_to_first_token(
                                state.first_token_time - state.created_time
                            )
                        else:
                            if completion_tokens >= 2:
                                # Compute time_per_output_token for the streaming case
                                self.metrics_collector.observe_time_per_output_token(
                                    (time.time() - state.first_token_time)
                                    / (completion_tokens - 1)
                                )

                        if state.finished:
                            self.metrics_collector.inc_prompt_tokens(
                                recv_obj.prompt_tokens[i]
                            )
                            self.metrics_collector.inc_generation_tokens(
                                completion_tokens
                            )
                            self.metrics_collector.observe_e2e_request_latency(
                                time.time() - state.created_time
                            )
                            # Compute time_per_output_token for the non-streaming case
                            if (
                                hasattr(state.obj, "stream")
                                and not state.obj.stream
                                and completion_tokens >= 1
                            ):
                                self.metrics_collector.observe_time_per_output_token(
                                    (time.time() - state.created_time)
                                    / completion_tokens
                                )
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id if recv_obj.success else None
                )
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future once all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
                self.init_weights_update_group_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from tensor"
                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                self.get_weights_by_name_communicator.handle_recv(recv_obj)
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.input_token_logprobs_val[recv_obj_index],
            recv_obj.input_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.output_token_logprobs_val[recv_obj_index],
            recv_obj.output_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
            recv_obj_index
        ]

        if top_logprobs_num > 0:
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_top_logprobs_val[recv_obj_index],
                recv_obj.input_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.output_top_logprobs_val[recv_obj_index],
                recv_obj.output_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
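        # Returns a list of (logprob, token_id, token_text) tuples; token_text
        # is None when decode_to_text is False.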
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True


T = TypeVar("T")


class _Communicator(Generic[T]):
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj):
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        result_values = self._result_values
        self._result_future = self._result_values = None
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)
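

# Illustrative fan-out pattern for _Communicator (a sketch of how it is used
# above, not an executable example): an endpoint coroutine sends one request
# and awaits `fan_out` replies, which arrive through handle_loop():
#
#     communicator = _Communicator(send_to_scheduler, fan_out=server_args.dp_size)
#     results = await communicator(req)        # in the endpoint coroutine
#     ...
#     communicator.handle_recv(recv_obj)       # in handle_loop(), once per reply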