# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.conn import KVBootstrapServer
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        if self.model_config.is_multimodal:
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

            # We want to parallelize the image pre-processing, so we create an executor for it.
            # We create the image processor regardless of skip_tokenizer_init so that
            # images are still encoded even when skip_tokenizer_init is enabled.
            self.image_processor = get_image_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.image_processor = get_dummy_image_processor()

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()

        # The reader-writer lock that keeps weight syncs from running while requests
        # are in progress, plus the future holding the weight-sync result.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio.Future

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.start_profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.health_check_communicator = _Communicator(self.send_to_scheduler, 1)
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

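        # Dispatch replies received from the detokenizer/scheduler to their
        # handlers based on the type of the received object.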
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.start_profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        # For disaggregation, start the KV bootstrap server only on the prefill
        # tokenizer manager.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.bootstrap_server = KVBootstrapServer(
                self.server_args.disaggregation_bootstrap_port
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
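        """Handle a generate (or embedding) request: tokenize the input, forward it
        to the scheduler, and yield output dicts as they arrive (one per chunk when
        streaming).

        Minimal usage sketch (assuming an already-constructed manager and request object):

            async for out in tokenizer_manager.generate_request(obj, request):
                ...
        """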
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.image_processor.process_images_async(
            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]
        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            token_ids_logprob = obj.token_ids_logprob
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

        input_token_num = len(input_ids) if input_ids is not None else 0
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        if (
            obj.sampling_params.get("max_new_tokens") is not None
            and obj.sampling_params.get("max_new_tokens") + input_token_num
            >= self.context_len
        ):
            raise ValueError(
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of "
                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
                f"completion. Please reduce the number of tokens in the input "
                f"messages or the completion to fit within the limit."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
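        """Register the request state locally and forward the tokenized request to the scheduler."""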
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
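            # Wake up periodically so that a client disconnect can abort the
            # request even when no new output has arrived yet.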
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(
                        "Request is disconnected from the client side. "
                        f"Abort request {obj.rid}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(
                        "Request is disconnected from the client side. "
                        f"Abort request {obj.rid}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
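            # A throwaway copy with max_new_tokens=0 prefills (and caches) the shared
            # prompt once before the n samples are expanded below.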
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
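        """Ask the scheduler to flush its caches (fire-and-forget)."""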
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
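        """Drop the local state for `rid` and ask the scheduler to abort it."""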
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
    ):
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
        )
        result = (await self.start_profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    def stop_profile(self):
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # Default the load format to the one specified in server_args.
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:
            # Always hold the write lock so that a weight sync cannot run while
            # requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
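        # With dp_size == 1 a single result is awaited; with dp_size > 1 the results
        # from all data-parallel workers are aggregated in
        # _handle_update_weights_from_disk_req_output before the future resolves.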
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update_weights_from_distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update_weights_from_tensor"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

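        # The scheduler's reply resolves this future in
        # _handle_open_session_req_output (session id on success, None on failure).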
        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        await self.send_to_scheduler.send_pyobj(obj)

    async def get_internal_state(self) -> Dict[Any, Any]:
        req = GetInternalStateReq()
        res: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        return res[0].internal_state

    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                    ]
                )
            elif self.log_requests_level == 1:
                max_length = 2048
            elif self.log_requests_level == 2:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
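        # Lazily start the background tasks (the result-handling loop and the
        # SIGTERM watchdog) on the first call; later calls are no-ops.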
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        # We cannot add a signal handler when the tokenizer manager is not in
        # the main thread due to a CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that receives results from the detokenizer and dispatches them."""

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                out_dict = {
                    "text": recv_obj.output_strs[i],
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
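                # When streaming token ids, return only the ids generated since the
                # previous chunk; otherwise return the full output sequence.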
                if self.server_args.stream_output and state.obj.stream:
                    output_token_ids = recv_obj.output_ids[i][
                        state.last_output_offset :
                    ]
                    state.last_output_offset = len(recv_obj.output_ids[i])
                else:
                    output_token_ids = recv_obj.output_ids[i]

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError()
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.input_token_logprobs_val[recv_obj_index],
            recv_obj.input_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.output_token_logprobs_val[recv_obj_index],
            recv_obj.output_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_top_logprobs_val[recv_obj_index],
                recv_obj.input_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.output_top_logprobs_val[recv_obj_index],
                recv_obj.output_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
                    return_text_in_logprobs,
                )
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

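        # The first chunk of a request records time-to-first-token; subsequent chunks
        # record inter-token latency normalized by the number of new tokens.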
        if state.first_token_time == 0.0:
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))

    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # Set the future once all results are received.
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio task raises an exception that never gets printed.
    This wrapper logs the exception and then terminates the process.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


class SignalHandler:
    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True


T = TypeVar("T")


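# A small request/response helper: it sends one object to the scheduler and waits
# until `fan_out` replies (one per data-parallel rank) have been collected via
# `handle_recv`. Concurrent callers are serialized through `_ready_queue`.
# Typical in-class usage (sketch):
#     communicator = _Communicator(self.send_to_scheduler, server_args.dp_size)
#     results = await communicator(req)  # list with one reply per DP rank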
class _Communicator(Generic[T]):
    """Note: The communicator only runs up to one in-flight request at any time."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Future] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()