engine.py 31.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The entry point of inference server. (SRT = SGLang Runtime)

This file implements python APIs for the inference engine.
"""

import asyncio
import atexit
import dataclasses
import logging
import multiprocessing as mp
import os
import signal
import threading
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union

30
31
import zmq
import zmq.asyncio
32
from PIL.Image import Image
33

34
35
36
37
38
39
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import torch
import uvloop

40
from sglang.srt.entrypoints.EngineBase import EngineBase
41
42
43
44
45
46
47
48
49
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
50
    LoadLoRAAdapterReqInput,
51
    MultimodalDataInputFormat,
52
53
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
54
55
    RpcReqInput,
    RpcReqOutput,
56
    UnloadLoRAAdapterReqInput,
57
    UpdateWeightFromDiskReqInput,
58
59
60
61
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
62
from sglang.srt.managers.template_manager import TemplateManager
63
from sglang.srt.managers.tokenizer_manager import TokenizerManager
64
65
66
67
68
69
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    assert_pkg_version,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
70
    get_bool_env_var,
71
    get_zmq_socket,
72
    is_cuda,
73
    kill_process_tree,
74
    launch_dummy_health_check_server,
75
76
77
78
79
80
81
82
83
    prepare_model_and_tokenizer,
    set_prometheus_multiproc_dir,
    set_ulimit,
)
from sglang.version import __version__

logger = logging.getLogger(__name__)
# Install uvloop's event-loop policy for this process; the synchronous Engine
# wrappers below obtain their loop via `asyncio.get_event_loop()`.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

84
85
# Evaluated once at import time; `_set_envs_and_config` consults it to decide
# whether the installed sgl-kernel version must be checked.
_is_cuda = is_cuda()


class Engine(EngineBase):
    """
    The entry point to the inference engine.

    - The engine consists of three components:
        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
    2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.

    Most public methods are synchronous wrappers that drive a TokenizerManager
    coroutine to completion with `asyncio.get_event_loop().run_until_complete`;
    the `async_*` variants await the same coroutines directly.
    """

    def __init__(self, **kwargs):
        """
        The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.
        Please refer to `ServerArgs` for the documentation.

        A prebuilt `ServerArgs` may instead be passed as the `server_args` keyword.
        When building from keyword arguments, `log_level` defaults to "error".
        """
        if "server_args" in kwargs:
            # Directly load server_args
            server_args = kwargs["server_args"]
        else:
            # Construct server_args from kwargs
            if "log_level" not in kwargs:
                # Do not print logs by default
                kwargs["log_level"] = "error"
            server_args = ServerArgs(**kwargs)

        # Shutdown the subprocesses automatically when the program exits
        atexit.register(self.shutdown)

        # Allocate ports for inter-process communications
        self.port_args = PortArgs.init_new(server_args)
        logger.info(f"{server_args=}")

        # Launch subprocesses
        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
            server_args=server_args,
            port_args=self.port_args,
        )
        self.server_args = server_args
        self.tokenizer_manager = tokenizer_manager
        self.template_manager = template_manager
        self.scheduler_info = scheduler_info

        # DEALER socket on `rpc_ipc_name`, used by `collective_rpc`.
        context = zmq.Context(2)
        self.send_to_rpc = get_zmq_socket(
            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
        )

    def _validate_data_parallel_rank(self, data_parallel_rank: Optional[int]) -> None:
        """Validate an explicitly requested DP rank when DP attention is enabled.

        `None` is accepted and means the default dispatch is used.

        Raises:
            ValueError: if the rank is negative or not smaller than `dp_size`.
        """
        if not self.server_args.enable_dp_attention:
            return
        if data_parallel_rank is None:
            logger.debug("data_parallel_rank not provided, using default dispatch")
        elif data_parallel_rank < 0:
            raise ValueError("data_parallel_rank must be non-negative")
        elif data_parallel_rank >= self.server_args.dp_size:
            raise ValueError(
                f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
            )

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
        # Can be formatted as:
        # - Single image for a single request
        # - List of images (one per request in a batch)
        # - List of lists of images (multiple images per request)
        # See also python/sglang/srt/utils.py:load_image for more details.
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
        bootstrap_host: Optional[Union[List[str], str]] = None,
        bootstrap_port: Optional[Union[List[int], int]] = None,
        bootstrap_room: Optional[Union[List[int], int]] = None,
        data_parallel_rank: Optional[int] = None,
    ) -> Union[Dict, Iterator[Dict]]:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.

        Returns the result dict, or, when `stream` is true, a synchronous iterator
        that yields result chunks as they arrive.

        Raises:
            ValueError: if DP attention is enabled and `data_parallel_rank` is out of range.
        """
        self._validate_data_parallel_rank(data_parallel_rank)

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
            data_parallel_rank=data_parallel_rank,
        )
        loop = asyncio.get_event_loop()
        generator = self.tokenizer_manager.generate_request(obj, None)

        if stream:

            def generator_wrapper():
                # Pull one chunk at a time from the async generator by running the
                # loop until that chunk is produced.
                while True:
                    try:
                        chunk = loop.run_until_complete(generator.__anext__())
                        yield chunk
                    except StopAsyncIteration:
                        break

            return generator_wrapper()
        else:
            ret = loop.run_until_complete(generator.__anext__())
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
        # Can be formatted as:
        # - Single image for a single request
        # - List of images (one per request in a batch)
        # - List of lists of images (multiple images per request)
        # See also python/sglang/srt/utils.py:load_image for more details.
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
        bootstrap_host: Optional[Union[List[str], str]] = None,
        bootstrap_port: Optional[Union[List[int], int]] = None,
        bootstrap_room: Optional[Union[List[int], int]] = None,
        data_parallel_rank: Optional[int] = None,
    ) -> Union[Dict, AsyncIterator[Dict]]:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.

        Returns the result dict, or, when `stream` is true, the async generator of
        result chunks itself.

        Raises:
            ValueError: if DP attention is enabled and `data_parallel_rank` is out of range.
        """
        self._validate_data_parallel_rank(data_parallel_rank)
        logger.debug(f"data_parallel_rank: {data_parallel_rank}")

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            return_hidden_states=return_hidden_states,
            stream=stream,
            custom_logit_processor=custom_logit_processor,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
            data_parallel_rank=data_parallel_rank,
        )
        generator = self.tokenizer_manager.generate_request(obj, None)

        # Use truthiness (as `generate` does) so any truthy `stream` value that
        # makes the request stream also returns the full stream to the caller.
        if stream:
            return generator
        else:
            return await generator.__anext__()

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
    ) -> Dict:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """
        obj = EmbeddingReqInput(
            text=prompt,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
        )
        loop = asyncio.get_event_loop()
        generator = self.tokenizer_manager.generate_request(obj, None)
        ret = loop.run_until_complete(generator.__anext__())
        return ret

    async def async_encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
        image_data: Optional[MultimodalDataInputFormat] = None,
        audio_data: Optional[MultimodalDataInputFormat] = None,
        video_data: Optional[MultimodalDataInputFormat] = None,
    ) -> Dict:
        """
        Asynchronous version of encode method.

        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """
        obj = EmbeddingReqInput(
            text=prompt,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
        )
        generator = self.tokenizer_manager.generate_request(obj, None)
        return await generator.__anext__()

    def rerank(
        self,
        prompt: List[List[str]],
    ) -> Dict:
        """
        Score the given text pairs as a cross-encoder request.

        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """
        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
        loop = asyncio.get_event_loop()
        generator = self.tokenizer_manager.generate_request(obj, None)
        ret = loop.run_until_complete(generator.__anext__())
        return ret

    def shutdown(self):
        """Shutdown the engine by killing all child processes of this process."""
        kill_process_tree(os.getpid(), include_parent=False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()
        # Never suppress an exception raised inside the `with` body.
        return False

    def flush_cache(self):
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.tokenizer_manager.flush_cache())

    def start_profile(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.tokenizer_manager.start_profile())

    def stop_profile(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.tokenizer_manager.stop_profile())

    def start_expert_distribution_record(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            self.tokenizer_manager.start_expert_distribution_record()
        )

    def stop_expert_distribution_record(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            self.tokenizer_manager.stop_expert_distribution_record()
        )

    def dump_expert_distribution_record(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            self.tokenizer_manager.dump_expert_distribution_record()
        )

    def get_server_info(self):
        """Return server args, scheduler info, internal states and the version as one dict."""
        loop = asyncio.get_event_loop()
        internal_states = loop.run_until_complete(
            self.tokenizer_manager.get_internal_state()
        )
        return {
            **dataclasses.asdict(self.tokenizer_manager.server_args),
            **self.scheduler_info,
            "internal_states": internal_states,
            "version": __version__,
        }

    def init_weights_update_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
    ):
        """Initialize parameter update group."""
        obj = InitWeightsUpdateGroupReqInput(
            master_address=master_address,
            master_port=master_port,
            rank_offset=rank_offset,
            world_size=world_size,
            group_name=group_name,
            backend=backend,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.init_weights_update_group(obj, None)
        )

    def update_weights_from_distributed(
        self,
        names: list[str],
        dtypes: list[str],
        shapes: list[list[int]],
        group_name: str = "weight_update_group",
        flush_cache: bool = True,
    ):
        """Update weights from distributed source."""
        obj = UpdateWeightsFromDistributedReqInput(
            names=names,
            dtypes=dtypes,
            shapes=shapes,
            group_name=group_name,
            flush_cache=flush_cache,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.update_weights_from_distributed(obj, None)
        )

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
        flush_cache: bool = True,
    ):
        """Update weights from in-memory tensors passed by the caller.

        Unless `load_format == "flattened_bucket"` (where `named_tensors` is passed
        through unchanged), the tensors are serialized once per TP rank. If more
        updates will follow, set `flush_cache` to False to avoid duplicated cache
        cleaning operations.
        """
        if load_format == "flattened_bucket":
            serialized_named_tensors = named_tensors
        else:
            serialized_named_tensors = [
                MultiprocessingSerializer.serialize(named_tensors)
                for _ in range(self.server_args.tp_size)
            ]
        obj = UpdateWeightsFromTensorReqInput(
            serialized_named_tensors=serialized_named_tensors,
            load_format=load_format,
            flush_cache=flush_cache,
        )
        loop = asyncio.get_event_loop()

        return loop.run_until_complete(
            self.tokenizer_manager.update_weights_from_tensor(obj, None)
        )

    def update_weights_from_disk(
        self,
        model_path: str,
        load_format: Optional[str] = None,
    ):
        """Update the weights from disk inplace without re-launching the engine.

        This method allows updating the model weights from disk without restarting
        the engine. It can be used to load a different model or update weights with
        new training.
        """
        obj = UpdateWeightFromDiskReqInput(
            model_path=model_path,
            load_format=load_format,
        )

        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.update_weights_from_disk(obj, None)
        )

    def get_weights_by_name(self, name: str, truncate_size: int = 100):
        """Get weights by parameter name."""
        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.get_weights_by_name(obj, None)
        )

    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
        """Load a new LoRA adapter without re-launching the engine."""

        obj = LoadLoRAAdapterReqInput(
            lora_name=lora_name,
            lora_path=lora_path,
            pinned=pinned,
        )

        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.load_lora_adapter(obj, None)
        )

    def unload_lora_adapter(self, lora_name: str):
        """Unload a LoRA adapter without re-launching the engine."""

        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)

        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.unload_lora_adapter(obj, None)
        )

    def release_memory_occupation(self, tags: Optional[List[str]] = None):
        obj = ReleaseMemoryOccupationReqInput(tags=tags)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.release_memory_occupation(obj, None)
        )

    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
        obj = ResumeMemoryOccupationReqInput(tags=tags)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.resume_memory_occupation(obj, None)
        )

    def collective_rpc(self, method: str, **kwargs):
        """
        Execute an RPC call on all scheduler processes.

        Sends `RpcReqInput(method=method, parameters=kwargs)` over the RPC DEALER
        socket and blocks until the reply arrives.

        Raises:
            AssertionError: if the reply is not an `RpcReqOutput`, or if it reports
                failure (its `message` becomes the error text).
        """
        obj = RpcReqInput(method=method, parameters=kwargs)
        self.send_to_rpc.send_pyobj(obj)
        # Default flags (0) give a blocking receive. `zmq.BLOCKY` is a context
        # option constant, not a recv flag, so it must not be passed here.
        recv_req = self.send_to_rpc.recv_pyobj()
        # Explicit raises instead of `assert`, so a failed RPC is not silently
        # ignored under `python -O`; AssertionError is kept for existing callers.
        if not isinstance(recv_req, RpcReqOutput):
            raise AssertionError(
                f"Unexpected RPC reply type: {type(recv_req).__name__}"
            )
        if not recv_req.success:
            raise AssertionError(recv_req.message)

    def save_remote_model(self, **kwargs):
        self.collective_rpc("save_remote_model", **kwargs)

    def save_sharded_model(self, **kwargs):
        self.collective_rpc("save_sharded_model", **kwargs)

    def score(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
    ) -> List[List[float]]:
        """
        Score the probability of specified token IDs appearing after the given (query + item) pair. For example:
        query = "<|user|>Is the following city the capital of France? "
        items = ["Paris <|assistant|>", "London <|assistant|>", "Berlin <|assistant|>"]
        label_token_ids = [2332, 1223] # Token IDs for "Yes" and "No"
        item_first = False

        This would pass the following prompts to the model:
        "<|user|>Is the following city the capital of France? Paris <|assistant|>"
        "<|user|>Is the following city the capital of France? London <|assistant|>"
        "<|user|>Is the following city the capital of France? Berlin <|assistant|>"
        The api would then return the probabilities of the model producing "Yes" and "No" as the next token.
        The output would look like:
        [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]


        Args:
            query: The query text or pre-tokenized query token IDs. Must be provided.
            items: The item text(s) or pre-tokenized item token IDs. Must be provided.
            label_token_ids: List of token IDs to compute probabilities for. If None, no token probabilities will be computed.
            apply_softmax: Whether to normalize probabilities using softmax.
            item_first: If True, prepend items to query. Otherwise append items to query.

        Returns:
            A list with one entry per item. Each entry is a list of probabilities
            ordered like `label_token_ids` (see the example output above).

        Raises:
            ValueError: If query is not provided, or if items is not provided,
                      or if token IDs are out of vocabulary, or if logprobs are not available for the specified tokens.
        """
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.score_request(
                query=query,
                items=items,
                label_token_ids=label_token_ids,
                apply_softmax=apply_softmax,
                item_first=item_first,
                request=None,
            )
        )

    async def async_score(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
    ) -> List[List[float]]:
        """
        Asynchronous version of score method.

        See score() for detailed documentation.
        """
        return await self.tokenizer_manager.score_request(
            query=query,
            items=items,
            label_token_ids=label_token_ids,
            apply_softmax=apply_softmax,
            item_first=item_first,
            request=None,
        )


def _set_envs_and_config(server_args: ServerArgs):
    """Prepare process-wide environment and runtime state before launching workers.

    In order, this:
      * exports TF / NCCL / CUDA / TRT-LLM environment variables;
      * calls `set_prometheus_multiproc_dir()` when metrics are enabled;
      * calls `set_ulimit()`;
      * checks the flashinfer_python version (flashinfer backend only) and the
        sgl-kernel version (CUDA only, unless SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK);
      * installs a SIGQUIT handler that kills the whole process tree;
      * forces the multiprocessing start method to "spawn".
    """
    env = os.environ

    # Global environment variables, written in a fixed order.
    env["TF_CPP_MIN_LOG_LEVEL"] = "3"
    symm_mem_enabled = server_args.enable_symm_mem
    env["NCCL_CUMEM_ENABLE"] = str(int(symm_mem_enabled))
    if not symm_mem_enabled:
        env["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
    fixed_settings = (
        ("CUDA_DEVICE_MAX_CONNECTIONS", "4"),
        ("CUDA_MODULE_LOADING", "AUTO"),
        # flashinfer uses this environment variable for various kernels from MoE to quant kernels
        ("TRTLLM_ENABLE_PDL", "1"),
    )
    for var_name, var_value in fixed_settings:
        env[var_name] = var_value

    # Prometheus multiprocess directory is only needed when metrics are on.
    if server_args.enable_metrics:
        set_prometheus_multiproc_dir()

    set_ulimit()

    # Package version checks.
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer_python",
            "0.2.11.post3",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )
    check_sgl_kernel = _is_cuda and not get_bool_env_var(
        "SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"
    )
    if check_sgl_kernel:
        assert_pkg_version(
            "sgl-kernel",
            "0.3.5",
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
        )

    if True:  # Keep this check for internal code compatibility
        # Child processes send SIGQUIT to this process when they hit an error;
        # the handler then tears down the whole process tree. This handler covers
        # the launch phase and may later be replaced by the tokenizer manager's
        # running_phase_sigquit_handler.
        def launch_phase_sigquit_handler(signum, frame):
            logger.error(
                "Received sigquit from a child process. It usually means the child failed."
            )
            kill_process_tree(os.getpid())

        signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)

    # Set mp start method
    mp.set_start_method("spawn", force=True)


def _launch_scheduler_processes(
    server_args: ServerArgs, port_args: PortArgs
) -> Tuple[List[mp.Process], List]:
    """Start this node's scheduler subprocess(es).

    With `dp_size == 1`, one scheduler process is started per (pp_rank, tp_rank)
    pair assigned to this node. Otherwise a single data-parallel controller
    process is started.

    Returns:
        `(scheduler_procs, scheduler_pipe_readers)`: the started processes and the
        read ends of the pipes on which each reports its startup status.
    """
    scheduler_procs = []
    if server_args.dp_size == 1:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        # TP ranks owned by this node.
        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        # PP ranks owned by this node.
        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                # NOTE(review): gpu_id_step scales only the TP offset, not the PP
                # offset — confirm this is intended.
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=run_scheduler_process,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,
                        writer,
                        None,
                    ),
                )

                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    return scheduler_procs, scheduler_pipe_readers


def _launch_subprocesses(
    server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[Optional[TokenizerManager], Optional[TemplateManager], Optional[Dict]]:
    """
    Launch the TokenizerManager in the main process, the scheduler subprocess(es)
    (or a data-parallel controller when dp_size > 1), and the DetokenizerManager
    in another subprocess.

    Returns:
        `(tokenizer_manager, template_manager, scheduler_info)` on node rank 0.
        Nodes with `node_rank >= 1` start no tokenizer/detokenizer and return
        `(None, None, None)`. They return immediately after all schedulers report
        ready when `SGLANG_BLOCK_NONZERO_RANK_CHILDREN == "0"`; otherwise they
        return only after every scheduler process has exited.

    Raises:
        EOFError: if a scheduler dies before reporting its status (node rank 0).
        RuntimeError: if a scheduler reports a status other than "ready".
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
        logger.info(f"{server_args=}")

    # If using model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs, scheduler_pipe_readers = _launch_scheduler_processes(
        server_args, port_args
    )

    if server_args.node_rank >= 1:
        # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
        # so they can just wait here.

        for reader in scheduler_pipe_readers:
            data = reader.recv()
            assert data["status"] == "ready"

        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            # When using `Engine` as a Python API, we don't want to block here.
            return None, None, None

        launch_dummy_health_check_server(
            server_args.host, server_args.port, server_args.enable_metrics
        )

        for proc in scheduler_procs:
            proc.join()
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
        return None, None, None

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize templates
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )

    # Wait for the model to finish loading
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            # The writer end closed without a status message: the child exited.
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise

        if data["status"] != "ready":
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        scheduler_infos.append(data)

    # Assume all schedulers have the same scheduler_info
    scheduler_info = scheduler_infos[0]
    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    return tokenizer_manager, template_manager, scheduler_info